Merge branch 'SWivid:main' into main

This commit is contained in:
Rino
2024-11-03 11:40:25 +07:00
committed by GitHub
2 changed files with 15 additions and 5 deletions

View File

@@ -119,12 +119,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
asr_pipe = None
def initialize_asr_pipeline(device=device, dtype=None):
    """Build the Whisper ASR pipeline and store it in the module-level ``asr_pipe``.

    Args:
        device: Device handed to the pipeline, e.g. ``"cuda"``, ``"cuda:0"``,
            ``"cpu"`` or a ``torch.device``. Defaults to the module-level
            ``device``.
        dtype: Torch dtype passed to the pipeline as ``torch_dtype``. When
            ``None`` it is chosen automatically: ``torch.float16`` on a CUDA
            device whose compute-capability major version is >= 6,
            ``torch.float32`` otherwise.

    Returns:
        None. The created pipeline is assigned to the global ``asr_pipe``.
    """
    global asr_pipe

    if dtype is None:
        # Match on the "cuda" prefix rather than the exact string so indexed
        # devices ("cuda:0") and torch.device objects are recognised too;
        # an exact comparison would silently fall back to float32 for them.
        on_cuda = str(device).startswith("cuda")
        if on_cuda and torch.cuda.get_device_properties(device).major >= 6:
            dtype = torch.float16
        else:
            dtype = torch.float32

    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=dtype,
        device=device,
    )

View File

@@ -325,7 +325,9 @@ class Trainer:
if self.log_samples and self.accelerator.is_local_main_process:
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
torchaudio.save(
f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
)
with torch.inference_mode():
generated, _ = self.accelerator.unwrap_model(self.model).sample(
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
@@ -336,8 +338,12 @@ class Trainer:
sway_sampling_coef=sway_sampling_coef,
)
generated = generated.to(torch.float32)
gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
gen_audio = vocoder.decode(
generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
)
torchaudio.save(
f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
)
if global_step % self.last_per_steps == 0:
self.save_checkpoint(global_step, last=True)