formatting, sorting

This commit is contained in:
SWivid
2025-05-05 01:41:28 +08:00
parent b4efcd836a
commit 6d1a1e886a
40 changed files with 167 additions and 155 deletions

View File

@@ -3,11 +3,14 @@ repos:
# Ruff version.
rev: v0.11.2
hooks:
# Run the linter.
- id: ruff
name: ruff linter
args: [--fix]
# Run the formatter.
- id: ruff-format
name: ruff formatter
- id: ruff
name: ruff sorter
args: [--select, I, --fix]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:

View File

@@ -6,5 +6,5 @@ target-version = "py310"
dummy-variable-rgx = "^_.*$"
[lint.isort]
force-single-line = true
force-single-line = false
lines-after-imports = 2

View File

@@ -9,13 +9,13 @@ from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
transcribe,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
)
from f5_tts.model.utils import seed_everything

View File

@@ -4,6 +4,7 @@
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

View File

@@ -1,6 +1,7 @@
import os
import sys
sys.path.append(os.getcwd())
import argparse
@@ -23,6 +24,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"

View File

@@ -5,17 +5,16 @@ import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import (
get_librispeech_test,
run_asr_wer,
run_sim,
)
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))

View File

@@ -5,17 +5,16 @@ import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import (
get_seed_tts_test,
run_asr_wer,
run_sim,
)
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))

View File

@@ -14,20 +14,20 @@ from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
mel_spec_type,
target_rms,
cross_fade_duration,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
fix_duration,
cross_fade_duration,
device,
fix_duration,
infer_process,
load_model,
load_vocoder,
mel_spec_type,
nfe_step,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
speed,
sway_sampling_coef,
target_rms,
)

View File

@@ -18,6 +18,7 @@ import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
import spaces
@@ -33,15 +34,15 @@ def gpu_decorator(func):
return func
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
)
from f5_tts.model import DiT, UNetT
DEFAULT_TTS_MODEL = "F5-TTS_v1"

View File

@@ -1,5 +1,6 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
from importlib.resources import files
@@ -7,14 +8,15 @@ from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from cached_path import cached_path
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = (
"cuda"
if torch.cuda.is_available()

View File

@@ -4,6 +4,7 @@ import os
import sys
from concurrent.futures import ThreadPoolExecutor
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
@@ -14,6 +15,7 @@ from importlib.resources import files
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
@@ -27,10 +29,8 @@ from transformers import pipeline
from vocos import Vocos
from f5_tts.model import CFM
from f5_tts.model.utils import (
get_tokenizer,
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
_ref_audio_cache = {}

View File

@@ -1,9 +1,7 @@
from f5_tts.model.cfm import CFM
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.cfm import CFM
from f5_tts.model.trainer import Trainer

View File

@@ -10,19 +10,18 @@ d - dimension
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNorm_Final,
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNorm_Final,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)

View File

@@ -11,16 +11,15 @@ from __future__ import annotations
import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNorm_Final,
ConvPositionEmbedding,
MMDiTBlock,
AdaLayerNorm_Final,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)

View File

@@ -8,24 +8,24 @@ d - dimension
"""
from __future__ import annotations
from typing import Literal
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,
Attention,
AttnProcessor,
ConvNeXtV2Block,
ConvPositionEmbedding,
FeedForward,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)

View File

@@ -19,6 +19,7 @@ from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
# trainer

View File

@@ -5,11 +5,10 @@ import random
from collections import defaultdict
from importlib.resources import files
import torch
from torch.nn.utils.rnn import pad_sequence
import jieba
from pypinyin import lazy_pinyin, Style
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
# seed everything

View File

@@ -30,26 +30,27 @@ import argparse
import json
import os
import time
from typing import List, Dict, Union
from typing import Dict, List, Union
import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import jieba
from pypinyin import Style, lazy_pinyin
from datasets import load_dataset
import datasets
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
from f5_tts_trtllm import F5TTS
import tensorrt as trt
from tensorrt_llm.runtime.session import Session, TensorInfo
from tensorrt_llm.logger import logger
from tensorrt_llm._utils import trt_dtype_to_torch
torch.manual_seed(0)
@@ -381,8 +382,8 @@ def main():
import sys
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
F5TTS_model_cfg = dict(
dim=1024,

View File

@@ -44,7 +44,6 @@ python3 client_grpc.py \
import argparse
import asyncio
import json
import os
import time
import types

View File

@@ -23,10 +23,11 @@
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import numpy as np
import requests
import soundfile as sf
import numpy as np
import argparse
def get_args():

View File

@@ -1,18 +1,17 @@
import tensorrt as trt
import os
import math
import os
import time
from typing import List, Optional
from functools import wraps
from typing import List, Optional
import tensorrt as trt
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):

View File

@@ -24,16 +24,17 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.dlpack import from_dlpack, to_dlpack
import torchaudio
import jieba
import triton_python_backend_utils as pb_utils
from pypinyin import Style, lazy_pinyin
import os
import jieba
import torch
import torch.nn.functional as F
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
from torch.utils.dlpack import from_dlpack, to_dlpack
def get_tokenizer(vocab_file_path: str):

View File

@@ -34,6 +34,7 @@ from .deepseek_v2.model import DeepseekV2ForCausalLM
from .dit.model import DiT
from .eagle.model import EagleForCausalLM
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .f5tts.model import F5TTS
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
@@ -54,12 +55,12 @@ from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodi
from .mpt.model import MPTForCausalLM, MPTModel
from .nemotron_nas.model import DeciLMForCausalLM
from .opt.model import OPTForCausalLM, OPTModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .phi.model import PhiForCausalLM, PhiModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .qwen.model import QWenForCausalLM
from .recurrentgemma.model import RecurrentGemmaForCausalLM
from .redrafter.model import ReDrafterForCausalLM
from .f5tts.model import F5TTS
__all__ = [
"BertModel",

View File

@@ -1,23 +1,20 @@
from __future__ import annotations
import sys
import os
import sys
from collections import OrderedDict
import tensorrt as trt
from collections import OrderedDict
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt
from ...functional import Tensor, concat
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from ...functional import Tensor, concat
from ...module import Module, ModuleList
from tensorrt_llm._common import default_net
from ...layers import Linear
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
from .modules import (
TimestepEmbedding,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
)
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)

View File

@@ -3,33 +3,35 @@ from __future__ import annotations
import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import numpy as np
from tensorrt_llm._common import default_net
from ..._utils import trt_dtype_to_np, str_dtype_to_trt
from ..._utils import str_dtype_to_trt, trt_dtype_to_np
from ...functional import (
Tensor,
bert_attention,
cast,
chunk,
concat,
constant,
expand,
expand_dims,
expand_dims_like,
expand_mask,
gelu,
matmul,
permute,
shape,
silu,
slice,
permute,
expand_mask,
expand_dims_like,
unsqueeze,
matmul,
softmax,
squeeze,
cast,
gelu,
unsqueeze,
view,
)
from ...functional import expand_dims, view, bert_attention
from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
from ...module import Module

View File

@@ -40,6 +40,7 @@ import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window
support_clp_op = None
if th.__version__ >= "1.7.0":
from torch.fft import rfft as fft

View File

@@ -8,7 +8,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import safetensors.torch
import torch
from tensorrt_llm import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split, split_matrix_tp

View File

@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from conv_stft import STFT
from huggingface_hub import hf_hub_download
from vocos import Vocos
import argparse
opset_version = 17

View File

@@ -1,12 +1,13 @@
import sys
import os
import sys
sys.path.append(os.getcwd())
from f5_tts.model import CFM, DiT
import torch
import thop
import torch
from f5_tts.model import CFM, DiT
""" ~155M """

View File

@@ -1,10 +1,12 @@
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import socket
import time
import numpy as np
import pyaudio
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,6 @@
import argparse
import gc
import logging
import numpy as np
import queue
import socket
import struct
@@ -10,6 +9,7 @@ import traceback
import wave
from importlib.resources import files
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
@@ -18,12 +18,13 @@ from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
chunk_text,
preprocess_ref_audio_text,
load_vocoder,
load_model,
infer_batch_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

View File

@@ -1,12 +1,13 @@
import os
import sys
import signal
import subprocess # For invoking ffprobe
import shutil
import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess # For invoking ffprobe
import sys
from contextlib import contextmanager
sys.path.append(os.getcwd())
import argparse
@@ -16,12 +17,10 @@ from importlib.resources import files
from pathlib import Path
import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import (
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")

View File

@@ -7,20 +7,18 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import (
repetition_found,
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
out_zh = {

View File

@@ -1,17 +1,17 @@
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
# prepares Emilia dataset with the new format w/ Emilia-YODAS
import os
import json
import os
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import repetition_found
from f5_tts.model.utils import (
repetition_found,
)
# Define filters for exclusion
out_en = set()

View File

@@ -1,15 +1,17 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def deal_with_audio_dir(audio_dir):

View File

@@ -1,14 +1,16 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def main():

View File

@@ -4,15 +4,16 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from tqdm import tqdm
import torchaudio
from datasets import Dataset
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin

View File

@@ -5,9 +5,9 @@ from importlib.resources import files
from cached_path import cached_path
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# -------------------------- Dataset Settings --------------------------- #

View File

@@ -1,14 +1,12 @@
import gc
import json
import numpy as np
import os
import platform
import psutil
import queue
import random
import re
import signal
import shutil
import signal
import subprocess
import sys
import tempfile
@@ -16,21 +14,23 @@ import threading
import time
from glob import glob
from importlib.resources import files
from scipy.io import wavfile
import click
import gradio as gr
import librosa
import numpy as np
import psutil
import torch
import torchaudio
from cached_path import cached_path
from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import load_file, save_file
from scipy.io import wavfile
from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.infer.utils_infer import transcribe
from f5_tts.model.utils import convert_char_to_pinyin
training_process = None

View File

@@ -10,6 +10,7 @@ from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)