diff --git a/src/f5_tts/socket.py b/src/f5_tts/socket.py index 183f24e..22f7e7a 100644 --- a/src/f5_tts/socket.py +++ b/src/f5_tts/socket.py @@ -19,10 +19,14 @@ class TTSStreamingProcessor: # Load the model using the provided checkpoint and vocab files self.model = load_model( - DiT, - dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4), - ckpt_file, - vocab_file, + model_cls=DiT, + model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4), + ckpt_path=ckpt_file, + mel_spec_type="vocos", # or "bigvgan" depending on vocoder + vocab_file=vocab_file, + ode_method="euler", + use_ema=True, + device=self.device ).to(self.device, dtype=dtype) # Load the vocoder