rebase default sample_rate to 24khz for runtime

This commit is contained in:
SWivid
2025-06-04 11:22:31 +08:00
parent 7e37bc5d9a
commit 6fbe7592f5
3 changed files with 10 additions and 10 deletions

View File

@@ -220,8 +220,8 @@ def get_args():
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
@@ -244,7 +244,7 @@ async def send(
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
save_sample_rate: int = 24000,
):
total_duration = 0.0
latency_data = []
@@ -254,7 +254,7 @@ async def send(
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
duration = len(waveform) / sample_rate
lengths = np.array([[len(waveform)]], dtype=np.int32)
@@ -417,7 +417,7 @@ async def main():
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
save_sample_rate=24000,
)
)
tasks.append(task)

View File

@@ -82,7 +82,7 @@ def prepare_request(
samples,
reference_text,
target_text,
sample_rate=16000,
sample_rate=24000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
@@ -106,8 +106,8 @@ def prepare_request(
return data
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
@@ -129,7 +129,7 @@ if __name__ == "__main__":
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
assert sr == 24000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)

View File

@@ -33,7 +33,7 @@ parameters [
},
{
key: "reference_audio_sample_rate",
value: {string_value:"16000"}
value: {string_value:"24000"}
},
{
key: "vocoder",