mirror of https://github.com/SWivid/F5-TTS.git
catch all error messages and add support to fix mac gpu issues
@@ -9,19 +9,16 @@ from glob import glob
 import librosa
 import numpy as np
 from scipy.io import wavfile
 from tqdm import tqdm
 import shutil
 import time
 
 import json
 from datasets import Dataset
 from model.utils import convert_char_to_pinyin
 import signal
 import psutil
 import platform
 import subprocess
 from datasets.arrow_writer import ArrowWriter
 from datasets import load_dataset, load_from_disk
-
-import json
-
@@ -265,8 +262,20 @@ def start_training(dataset_name="",
                    finetune=True,
                    ):
 
     global training_process
 
+    path_project = os.path.join(path_data, dataset_name + "_pinyin")
+
+    if not os.path.isdir(path_project):
+        yield f"There is no project with name {dataset_name}", gr.update(interactive=True), gr.update(interactive=False)
+        return
+
+    file_raw = os.path.join(path_project, "raw.arrow")
+    if not os.path.isfile(file_raw):
+        yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
+        return
+
     # Check if a training process is already running
     if training_process is not None:
         return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
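For context, start_training is a generator: each yield above streams a status string plus button-state updates into the UI. Below is a minimal standalone sketch of that pattern, with hypothetical component names; note that in older Gradio versions streaming yields require the queue to be enabled.

    import gradio as gr

    def job(name):
        # each yield fills the three outputs bound below: info text, start button, stop button
        if not name:
            yield "No project name given.", gr.update(interactive=True), gr.update(interactive=False)
            return
        yield f"Training {name}...", gr.update(interactive=False), gr.update(interactive=True)

    with gr.Blocks() as demo:
        txt = gr.Text(label="project")
        info = gr.Text(label="info")
        start = gr.Button("start")
        stop = gr.Button("stop")
        start.click(fn=job, inputs=[txt], outputs=[info, start, stop])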
@@ -346,6 +355,8 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
     path_project_wavs = os.path.join(path_project, "wavs")
     file_metadata = os.path.join(path_project, "metadata.csv")
+
+    if audio_files is None: return "You need to load an audio file."
 
     if os.path.isdir(path_project_wavs):
         shutil.rmtree(path_project_wavs)
@@ -356,16 +367,17 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
 
     if user:
         file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
+        if file_audios == []: return "No audio file was found in the dataset."
     else:
         file_audios = audio_files
 
     print([file_audios])
 
     alpha = 0.5
     _max = 1.0
     slicer = Slicer(24000)
 
     num = 0
+    error_num = 0
     data = ""
     for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len(file_audios)):
@@ -381,18 +393,26 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
-            if (tmp_max > 1): chunk /= tmp_max
+            chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
             wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
 
-            text = transcribe(file_segment, language)
-            text = text.lower().strip().replace('"', "")
-
-            data += f"{name_segment}|{text}\n"
-
-            num += 1
+            try:
+                text = transcribe(file_segment, language)
+                text = text.lower().strip().replace('"', "")
+
+                data += f"{name_segment}|{text}\n"
+
+                num += 1
+            except:
+                error_num += 1
 
     with open(file_metadata, "w", encoding="utf-8") as f:
         f.write(data)
 
-    return f"transcribe complete samples : {num} in path {path_project_wavs}"
+    if error_num != 0:
+        error_text = f"\nerror files : {error_num}"
+    else:
+        error_text = ""
+
+    return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
 
 def format_seconds_to_hms(seconds):
     hours = int(seconds / 3600)
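The chunk = ... line above blends a peak-normalized copy of the segment with the original signal, weighted by alpha, before writing 16-bit PCM at 24 kHz. A self-contained sketch of that transform, with a zero-signal guard added as an assumption (the diff assumes non-silent chunks):

    import numpy as np

    def soft_normalize(chunk: np.ndarray, max_amp: float = 1.0, alpha: float = 0.5) -> np.ndarray:
        # blend of a peak-normalized copy and the raw signal, as in the hunk above
        tmp_max = np.abs(chunk).max()
        if tmp_max == 0:  # guard added here, not present in the diff
            return chunk
        return (chunk / tmp_max * (max_amp * alpha)) + (1 - alpha) * chunk

    # 16-bit PCM conversion as used with wavfile.write above:
    # wavfile.write(path, 24000, (soft_normalize(chunk) * 32767).astype(np.int16))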
@@ -408,6 +428,8 @@ def create_metadata(name_project,progress=gr.Progress()):
     file_raw = os.path.join(path_project, "raw.arrow")
     file_duration = os.path.join(path_project, "duration.json")
     file_vocab = os.path.join(path_project, "vocab.txt")
+
+    if not os.path.isfile(file_metadata): return f"The file {file_metadata} was not found."
 
     with open(file_metadata, "r", encoding="utf-8") as f:
         data = f.read()
@@ -419,11 +441,18 @@ def create_metadata(name_project,progress=gr.Progress()):
     count = len(data.split("\n"))
     length = 0
     result = []
+    error_files = []
     for line in progress.tqdm(data.split("\n"), total=count):
         sp_line = line.split("|")
         if len(sp_line) != 2: continue
         name_audio, text = sp_line[:2]
 
         file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
 
+        if not os.path.isfile(file_audio):
+            error_files.append(file_audio)
+            continue
+
         duration = get_audio_duration(file_audio)
         if duration < 2 or duration > 15: continue
         if len(text) < 4: continue
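get_audio_duration is called above but not shown in this diff. A plausible one-line helper, assuming librosa (imported at the top of the file) is used; note the keyword is path= in librosa >= 0.10 and filename= in earlier releases:

    import librosa

    def get_audio_duration(file_audio: str) -> float:
        # duration in seconds, read from the audio file
        return librosa.get_duration(path=file_audio)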
@@ -439,6 +468,10 @@ def create_metadata(name_project,progress=gr.Progress()):
 
         length += duration
 
+    if duration_list == []:
+        error_files_text = "\n".join(error_files)
+        return f"Error: No audio files found in the specified path:\n{error_files_text}"
+
     min_second = round(min(duration_list), 2)
     max_second = round(max(duration_list), 2)
@@ -450,9 +483,15 @@ def create_metadata(name_project,progress=gr.Progress()):
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
     file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
     if not os.path.isfile(file_vocab_finetune): return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
     shutil.copy2(file_vocab_finetune, file_vocab)
 
-    return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(length)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n"
+    if error_files != []:
+        error_text = "error files\n" + "\n".join(error_files)
+    else:
+        error_text = ""
+
+    return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(length)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
 
 def check_user(value):
     return gr.update(visible=not value), gr.update(visible=value)
@@ -468,13 +507,16 @@ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_w
     duration_list = data['duration']
     samples = len(duration_list)
 
-    gpu_properties = torch.cuda.get_device_properties(0)
-    total_memory = gpu_properties.total_memory / (1024 ** 3)
+    if torch.cuda.is_available():
+        gpu_properties = torch.cuda.get_device_properties(0)
+        total_memory = gpu_properties.total_memory / (1024 ** 3)
+    elif torch.backends.mps.is_available():
+        total_memory = psutil.virtual_memory().available / (1024 ** 3)
 
     if batch_size_type == "frame":
         batch = int(total_memory * 0.5)
         batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
-        batch_size_per_gpu = int(36800 / batch)
+        batch_size_per_gpu = int(38400 / batch)
     else:
         batch_size_per_gpu = int(total_memory / 8)
         batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
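The new branch above is the Mac GPU fix: MPS has no equivalent of get_device_properties, and Apple Silicon uses unified memory, so available system RAM (via psutil) stands in for VRAM. A sketch of the probe as a reusable helper, with a CPU-only fallback added as an assumption (the diff leaves total_memory undefined when neither backend is present):

    import torch
    import psutil

    def get_available_memory_gb() -> float:
        if torch.cuda.is_available():
            # NVIDIA: total VRAM of device 0
            return torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        if torch.backends.mps.is_available():
            # Apple Silicon: unified memory, so free system RAM is the budget
            return psutil.virtual_memory().available / (1024 ** 3)
        # CPU fallback (not in the diff): also use free system RAM
        return psutil.virtual_memory().available / (1024 ** 3)

The "frame" branch then spends roughly half of that figure as a frame budget, rounds it up to an even number, and derives batch_size_per_gpu as int(38400 / batch).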
@@ -509,13 +551,12 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
         if ema_model_state_dict is not None:
             new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
             torch.save(new_checkpoint, new_checkpoint_path)
-            print(f"New checkpoint saved at: {new_checkpoint_path}")
+            return f"New checkpoint saved at: {new_checkpoint_path}"
         else:
-            print("No 'ema_model_state_dict' found in the checkpoint.")
+            return "No 'ema_model_state_dict' found in the checkpoint."
 
     except Exception as e:
-        print(f"An error occurred: {e}")
+        return f"An error occurred: {e}"
 
 def vocab_check(project_name):
     name_project = project_name + "_pinyin"
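The hunk above turns extract_and_save_ema_model's print calls into return values so the messages reach the Gradio info box. Assembled from the visible lines, the whole function plausibly reads as follows (the torch.load call and checkpoint layout are inferred, not shown in the diff):

    import torch

    def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> str:
        try:
            # load on CPU so extraction works without a GPU
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            ema_model_state_dict = checkpoint.get('ema_model_state_dict')
            if ema_model_state_dict is not None:
                torch.save({'ema_model_state_dict': ema_model_state_dict}, new_checkpoint_path)
                return f"New checkpoint saved at: {new_checkpoint_path}"
            return "No 'ema_model_state_dict' found in the checkpoint."
        except Exception as e:
            return f"An error occurred: {e}"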
@@ -524,12 +565,17 @@ def vocab_check(project_name):
     file_metadata = os.path.join(path_project, "metadata.csv")
 
     file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
+    if not os.path.isfile(file_vocab):
+        return f"The file {file_vocab} was not found!"
 
     with open(file_vocab, "r", encoding="utf-8") as f:
         data = f.read()
 
     vocab = data.split("\n")
 
+    if not os.path.isfile(file_metadata):
+        return f"The file {file_metadata} was not found!"
+
     with open(file_metadata, "r", encoding="utf-8") as f:
         data = f.read()
 
@@ -548,6 +594,7 @@ def vocab_check(project_name):
 
     if miss_symbols == []: info = "You can train using your language!"
     else: info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
 
+    return info
 
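vocab_check compares the symbols used in metadata.csv against the pretraining vocabulary and, with the added return, now reports the result instead of dropping it. A simplified sketch of the comparison on raw characters (the real function presumably runs on pinyin-converted text, given the convert_char_to_pinyin import at the top of the file):

    def find_missing_symbols(metadata_text: str, vocab: set) -> list:
        # hypothetical helper: collect transcript characters absent from the vocab
        miss = set()
        for line in metadata_text.splitlines():
            parts = line.split("|")
            if len(parts) != 2:
                continue
            for ch in parts[1]:
                if ch not in vocab and ch != " ":
                    miss.add(ch)
        return sorted(miss)

    # miss_symbols = find_missing_symbols(data, set(vocab)) would feed the report above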
@@ -652,8 +699,9 @@ with gr.Blocks() as app:
         with gr.TabItem("reduse checkpoint"):
             txt_path_checkpoint = gr.Text(label="path checkpoint :")
             txt_path_checkpoint_small = gr.Text(label="path output :")
+            txt_info_reduse = gr.Text(label="info", value="")
             reduse_button = gr.Button("reduse")
-            reduse_button.click(fn=extract_and_save_ema_model, inputs=[txt_path_checkpoint, txt_path_checkpoint_small])
+            reduse_button.click(fn=extract_and_save_ema_model, inputs=[txt_path_checkpoint, txt_path_checkpoint_small], outputs=[txt_info_reduse])
 
         with gr.TabItem("vocab check experiment"):
             check_button = gr.Button("check vocab")
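The one-line fix above matters because Gradio only routes a handler's return value into the UI when a target component is listed in outputs=; without it, the strings returned by extract_and_save_ema_model were silently discarded. A minimal illustration with hypothetical names:

    import gradio as gr

    def handler(x):
        return f"processed: {x}"

    with gr.Blocks() as demo:
        inp = gr.Text(label="input")
        out = gr.Text(label="info")
        btn = gr.Button("run")
        # without outputs=[out], the return value would be dropped
        btn.click(fn=handler, inputs=[inp], outputs=[out])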
@@ -680,10 +728,4 @@ def main(port, host, share, api):
     )
 
 if __name__ == "__main__":
-    name = "my_speak"
-
-    # create_data_project(name)
-    # transcribe_all(name)
-    # create_metadata(name)
-
     main()