Mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-30 06:31:54 -08:00)
import sys
import os

# Run from the repository root: the local `model` package is resolved
# via the current working directory.
sys.path.append(os.getcwd())

# To enable the commented-out UNetT / MMDiT configurations below, those
# classes must be imported from `model` as well.
from model import M2_TTS, DiT

import torch
import thop

""" ~155M """
|
|
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
|
|
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
|
|
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
|
|
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
|
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
|
|
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
|
|
|
|
""" ~335M """
|
|
# FLOPs: 622.1 G, Params: 333.2 M
|
|
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
|
# FLOPs: 363.4 G, Params: 335.8 M
|
|
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
|
|
|
|
|
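# Note (assumption, not stated in this file): the active DiT configuration
# above appears to correspond to the F5-TTS Base backbone.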
model = M2_TTS(transformer=transformer)

# Dummy input sizes: 20 s of 24 kHz audio with a 256-sample hop
# gives int(20 * 24000 / 256) = 1875 mel frames.
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
duration = 20  # seconds
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150

# Profile the model with thop on a single dummy batch: a random mel
# spectrogram of shape (1, frame_length, n_mel_channels) and an all-zero
# integer text sequence of length text_length.
flops, params = thop.profile(
    model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
)
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")
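# Note (assumption): thop.profile() counts multiply-accumulate operations
# (MACs) along with the parameter count; they are printed as "FLOPs" here
# to match the reference numbers in the comments above.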