catch all error messages and add support to fix Mac GPU issues

This commit is contained in:
unknown
2024-10-18 00:42:16 +03:00
parent 549ee89b74
commit 34ccbcb451

View File

@@ -9,19 +9,16 @@ from glob import glob
import librosa
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm
import shutil
import time
import json
from datasets import Dataset
from model.utils import convert_char_to_pinyin
import signal
import psutil
import platform
import subprocess
from datasets.arrow_writer import ArrowWriter
from datasets import load_dataset, load_from_disk
import json
@@ -265,8 +262,20 @@ def start_training(dataset_name="",
finetune=True,
):
global training_process
path_project = os.path.join(path_data, dataset_name + "_pinyin")
if os.path.isdir(path_project)==False:
yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
return
file_raw = os.path.join(path_project,"raw.arrow")
if os.path.isfile(file_raw)==False:
yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
return
# Check if a training process is already running
if training_process is not None:
return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
@@ -346,6 +355,8 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
path_project_wavs = os.path.join(path_project,"wavs")
file_metadata = os.path.join(path_project,"metadata.csv")
if audio_files is None:return "You need to load an audio file."
if os.path.isdir(path_project_wavs):
shutil.rmtree(path_project_wavs)
@@ -356,16 +367,17 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
if user:
file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
if file_audios==[]:return "No audio file was found in the dataset."
else:
file_audios = audio_files
print([file_audios])
alpha = 0.5
_max = 1.0
slicer = Slicer(24000)
num = 0
error_num = 0
data=""
for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
@@ -381,18 +393,26 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
if(tmp_max>1):chunk/=tmp_max
chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
try:
text=transcribe(file_segment,language)
text = text.lower().strip().replace('"',"")
text=transcribe(file_segment,language)
text = text.lower().strip().replace('"',"")
data+= f"{name_segment}|{text}\n"
data+= f"{name_segment}|{text}\n"
num+=1
except:
error_num +=1
num+=1
with open(file_metadata,"w",encoding="utf-8") as f:
f.write(data)
return f"transcribe complete samples : {num} in path {path_project_wavs}"
if error_num!=[]:
error_text=f"\nerror files : {error_num}"
else:
error_text=""
return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
def format_seconds_to_hms(seconds):
hours = int(seconds / 3600)
@@ -408,6 +428,8 @@ def create_metadata(name_project,progress=gr.Progress()):
file_raw = os.path.join(path_project,"raw.arrow")
file_duration = os.path.join(path_project,"duration.json")
file_vocab = os.path.join(path_project,"vocab.txt")
if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata
with open(file_metadata,"r",encoding="utf-8") as f:
data=f.read()
@@ -419,11 +441,18 @@ def create_metadata(name_project,progress=gr.Progress()):
count=data.split("\n")
lenght=0
result=[]
error_files=[]
for line in progress.tqdm(data.split("\n"),total=count):
sp_line=line.split("|")
if len(sp_line)!=2:continue
name_audio,text = sp_line[:2]
name_audio,text = sp_line[:2]
file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
if os.path.isfile(file_audio)==False:
error_files.append(file_audio)
continue
duraction = get_audio_duration(file_audio)
if duraction<2 and duraction>15:continue
if len(text)<4:continue
@@ -439,6 +468,10 @@ def create_metadata(name_project,progress=gr.Progress()):
lenght+=duraction
if duration_list==[]:
error_files_text="\n".join(error_files)
return f"Error: No audio files found in the specified path : \n{error_files_text}"
min_second = round(min(duration_list),2)
max_second = round(max(duration_list),2)
@@ -450,9 +483,15 @@ def create_metadata(name_project,progress=gr.Progress()):
json.dump({"duration": duration_list}, f, ensure_ascii=False)
file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
if os.path.isfile(file_vocab_finetune==False):return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
shutil.copy2(file_vocab_finetune, file_vocab)
return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n"
if error_files!=[]:
error_text="error files\n" + "\n".join(error_files)
else:
error_text=""
return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
def check_user(value):
    """Build the pair of visibility updates driven by ``value``.

    Returns a 2-tuple of ``gr.update`` results: the first has ``visible``
    set to ``not value`` and the second has ``visible`` set to ``value``,
    so exactly one of the two targets is shown for a boolean input.
    """
    first_update = gr.update(visible=not value)
    second_update = gr.update(visible=value)
    return first_update, second_update
@@ -468,13 +507,16 @@ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_w
duration_list = data['duration']
samples = len(duration_list)
gpu_properties = torch.cuda.get_device_properties(0)
total_memory = gpu_properties.total_memory / (1024 ** 3)
if torch.cuda.is_available():
gpu_properties = torch.cuda.get_device_properties(0)
total_memory = gpu_properties.total_memory / (1024 ** 3)
elif torch.backends.mps.is_available():
total_memory = psutil.virtual_memory().available / (1024 ** 3)
if batch_size_type=="frame":
batch = int(total_memory * 0.5)
batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
batch_size_per_gpu = int(36800 / batch )
batch_size_per_gpu = int(38400 / batch )
else:
batch_size_per_gpu = int(total_memory / 8)
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
@@ -509,13 +551,12 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
if ema_model_state_dict is not None:
new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
torch.save(new_checkpoint, new_checkpoint_path)
print(f"New checkpoint saved at: {new_checkpoint_path}")
return f"New checkpoint saved at: {new_checkpoint_path}"
else:
print("No 'ema_model_state_dict' found in the checkpoint.")
return "No 'ema_model_state_dict' found in the checkpoint."
except Exception as e:
print(f"An error occurred: {e}")
return f"An error occurred: {e}"
def vocab_check(project_name):
name_project = project_name + "_pinyin"
@@ -524,12 +565,17 @@ def vocab_check(project_name):
file_metadata = os.path.join(path_project, "metadata.csv")
file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
if os.path.isfile(file_vocab)==False:
return f"the file {file_vocab} not found !"
with open(file_vocab,"r",encoding="utf-8") as f:
data=f.read()
vocab = data.split("\n")
if os.path.isfile(file_metadata)==False:
return f"the file {file_metadata} not found !"
with open(file_metadata,"r",encoding="utf-8") as f:
data=f.read()
@@ -548,6 +594,7 @@ def vocab_check(project_name):
if miss_symbols==[]:info ="You can train using your language !"
else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
return info
@@ -652,8 +699,9 @@ with gr.Blocks() as app:
with gr.TabItem("reduse checkpoint"):
txt_path_checkpoint = gr.Text(label="path checkpoint :")
txt_path_checkpoint_small = gr.Text(label="path output :")
txt_info_reduse = gr.Text(label="info",value="")
reduse_button = gr.Button("reduse")
reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small])
reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
with gr.TabItem("vocab check experiment"):
check_button = gr.Button("check vocab")
@@ -680,10 +728,4 @@ def main(port, host, share, api):
)
if __name__ == "__main__":
name="my_speak"
#create_data_project(name)
#transcribe_all(name)
#create_metadata(name)
main()