diff --git a/src/f5_tts/runtime/triton_trtllm/client_grpc.py b/src/f5_tts/runtime/triton_trtllm/client_grpc.py
index 5a4adc5..1eb9b5b 100644
--- a/src/f5_tts/runtime/triton_trtllm/client_grpc.py
+++ b/src/f5_tts/runtime/triton_trtllm/client_grpc.py
@@ -220,8 +220,8 @@ def get_args():
     return parser.parse_args()


-def load_audio(wav_path, target_sample_rate=16000):
-    assert target_sample_rate == 16000, "hard coding in server"
+def load_audio(wav_path, target_sample_rate=24000):
+    assert target_sample_rate == 24000, "hard coding in server"
     if isinstance(wav_path, dict):
         waveform = wav_path["array"]
         sample_rate = wav_path["sampling_rate"]
@@ -244,7 +244,7 @@ async def send(
     model_name: str,
     padding_duration: int = None,
     audio_save_dir: str = "./",
-    save_sample_rate: int = 16000,
+    save_sample_rate: int = 24000,
 ):
     total_duration = 0.0
     latency_data = []
@@ -254,7 +254,7 @@ async def send(
     for i, item in enumerate(manifest_item_list):
         if i % log_interval == 0:
             print(f"{name}: {i}/{len(manifest_item_list)}")
-        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
+        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
         duration = len(waveform) / sample_rate

         lengths = np.array([[len(waveform)]], dtype=np.int32)
@@ -417,7 +417,7 @@ async def main():
                 model_name=args.model_name,
                 audio_save_dir=args.log_dir,
                 padding_duration=1,
-                save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
+                save_sample_rate=24000,
             )
         )
         tasks.append(task)
diff --git a/src/f5_tts/runtime/triton_trtllm/client_http.py b/src/f5_tts/runtime/triton_trtllm/client_http.py
index 2f11f02..804ba5c 100644
--- a/src/f5_tts/runtime/triton_trtllm/client_http.py
+++ b/src/f5_tts/runtime/triton_trtllm/client_http.py
@@ -82,7 +82,7 @@ def prepare_request(
     samples,
     reference_text,
     target_text,
-    sample_rate=16000,
+    sample_rate=24000,
     audio_save_dir: str = "./",
 ):
     assert len(samples.shape) == 1, "samples should be 1D"
@@ -106,8 +106,8 @@ def prepare_request(
     return data


-def load_audio(wav_path, target_sample_rate=16000):
-    assert target_sample_rate == 16000, "hard coding in server"
+def load_audio(wav_path, target_sample_rate=24000):
+    assert target_sample_rate == 24000, "hard coding in server"
     if isinstance(wav_path, dict):
         samples = wav_path["array"]
         sample_rate = wav_path["sampling_rate"]
@@ -129,7 +129,7 @@ if __name__ == "__main__":
     url = f"{server_url}/v2/models/{args.model_name}/infer"

     samples, sr = load_audio(args.reference_audio)
-    assert sr == 16000, "sample rate hardcoded in server"
+    assert sr == 24000, "sample rate hardcoded in server"
     samples = np.array(samples, dtype=np.float32)
     data = prepare_request(samples, args.reference_text, args.target_text)

diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
index 4663f7c..07f0c78 100644
--- a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
+++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
@@ -33,7 +33,7 @@ parameters [
   },
   {
     key: "reference_audio_sample_rate",
-    value: {string_value:"16000"}
+    value: {string_value:"24000"}
   },
   {
     key: "vocoder",
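Note (not part of the diff): after this change, both clients and the Triton config agree that reference audio is handled at 24 kHz rather than 16 kHz, so any caller must now hand the server 24 kHz mono samples. The sketch below illustrates that client-side contract; it is a minimal, hypothetical loader, not the repository's actual `load_audio`, and it assumes `soundfile` and `librosa` are installed and that a file named `reference.wav` exists.

```python
# Hypothetical helper: load a reference WAV and bring it to the 24 kHz
# sample rate the server now hard-codes (see config.pbtxt above).
import soundfile as sf
import librosa

TARGET_SR = 24000  # must match reference_audio_sample_rate in config.pbtxt


def load_reference(wav_path: str):
    samples, sr = sf.read(wav_path)
    if samples.ndim > 1:  # mix down to mono if the file is multi-channel
        samples = samples.mean(axis=1)
    if sr != TARGET_SR:  # resample here instead of asserting, for illustration
        samples = librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR)
    return samples, TARGET_SR


samples, sr = load_reference("reference.wav")
assert sr == 24000, "server expects 24 kHz reference audio"
```

The shipped clients instead assert that the input is already 24 kHz ("hard coding in server"), so resampling up front, as sketched here, is only needed when the reference recordings are stored at another rate.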