mirror of https://github.com/SWivid/F5-TTS.git
synced 2026-04-28 08:43:06 -07:00
v1.1.20: refactor cache handling in the DiT, MMDiT, and UNetT classes (lazy init) to fix a training bug (EMA deepcopy failure)
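
For context, a minimal sketch (not the repo's code) of the failure this commit fixes and the lazy-init pattern it adopts. CachedBackbone is a hypothetical stand-in for the real DiT/MMDiT/UNetT modules, and the deepcopy call mirrors what EMA wrappers typically do when building a shadow copy of the model:

import copy
import threading

# An eagerly created threading.local() attribute (as in v1.1.19) makes the
# whole object un-deepcopyable: _thread._local instances cannot be pickled.
try:
    copy.deepcopy(threading.local())
except TypeError as err:
    print(err)  # cannot pickle '_thread._local' object

class CachedBackbone:  # hypothetical stand-in for DiT / MMDiT / UNetT
    def _get_cache_local(self):
        # Create the thread-local store only on first write, so instances
        # that never cache (training) hold nothing un-copyable.
        cache = self.__dict__.get("_cache_local")
        if cache is None:
            cache = threading.local()
            self.__dict__["_cache_local"] = cache
        return cache

    @property
    def text_cond(self):
        cache = self.__dict__.get("_cache_local")
        return getattr(cache, "text_cond", None) if cache is not None else None

    @text_cond.setter
    def text_cond(self, value):
        self._get_cache_local().text_cond = value

model = CachedBackbone()
ema_model = copy.deepcopy(model)  # succeeds: no thread-local exists yet
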
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "1.1.19"
+version = "1.1.20"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}

@@ -202,7 +202,6 @@ class DiT(nn.Module):
             average_upsampling=text_embedding_average_upsampling,
             conv_layers=conv_layers,
         )
-        self._cache_local = threading.local()  # thread-local storage for cache
         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
 
         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -235,21 +234,32 @@ class DiT(nn.Module):
 
         self.initialize_weights()
 
+    # `_cache_local` is lazily initialized on first inference-time cache write so that
+    # training models (which never touch the cache) stay deepcopy-friendly for EMA.
+    def _get_cache_local(self):
+        cache = self.__dict__.get("_cache_local")
+        if cache is None:
+            cache = threading.local()
+            self.__dict__["_cache_local"] = cache
+        return cache
+
     @property
     def text_cond(self):
-        return getattr(self._cache_local, "text_cond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_cond", None) if cache is not None else None
 
     @text_cond.setter
     def text_cond(self, value):
-        self._cache_local.text_cond = value
+        self._get_cache_local().text_cond = value
 
     @property
     def text_uncond(self):
-        return getattr(self._cache_local, "text_uncond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_uncond", None) if cache is not None else None
 
     @text_uncond.setter
     def text_uncond(self, value):
-        self._cache_local.text_uncond = value
+        self._get_cache_local().text_uncond = value
 
     def initialize_weights(self):
         # Zero-out AdaLN layers in DiT blocks:
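
The cache is a threading.local in the first place so that concurrent inference threads do not clobber each other's conditioning. A small demonstration, again using the hypothetical CachedBackbone stand-in from the sketch above (the up-front _get_cache_local() call just keeps the demo deterministic):

import threading

model = CachedBackbone()  # stand-in from the sketch above
model._get_cache_local()  # create the store once, before threads start

def worker(tag):
    model.text_cond = tag          # lands in this thread's storage only
    assert model.text_cond == tag  # unaffected by the other workers

threads = [threading.Thread(target=worker, args=(f"cond-{i}",)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(model.text_cond)  # None: the main thread never wrote to its slot
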
@@ -106,7 +106,6 @@ class MMDiT(nn.Module):
 
         self.time_embed = TimestepEmbedding(dim)
         self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
-        self._cache_local = threading.local()  # thread-local storage for cache
         self.audio_embed = AudioEmbedding(mel_dim, dim)
 
         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -137,21 +136,32 @@ class MMDiT(nn.Module):
 
         self.initialize_weights()
 
+    # `_cache_local` is lazily initialized on first inference-time cache write so that
+    # training models (which never touch the cache) stay deepcopy-friendly for EMA.
+    def _get_cache_local(self):
+        cache = self.__dict__.get("_cache_local")
+        if cache is None:
+            cache = threading.local()
+            self.__dict__["_cache_local"] = cache
+        return cache
+
     @property
     def text_cond(self):
-        return getattr(self._cache_local, "text_cond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_cond", None) if cache is not None else None
 
     @text_cond.setter
     def text_cond(self, value):
-        self._cache_local.text_cond = value
+        self._get_cache_local().text_cond = value
 
     @property
     def text_uncond(self):
-        return getattr(self._cache_local, "text_uncond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_uncond", None) if cache is not None else None
 
     @text_uncond.setter
     def text_uncond(self, value):
-        self._cache_local.text_uncond = value
+        self._get_cache_local().text_uncond = value
 
     def initialize_weights(self):
         # Zero-out AdaLN layers in MMDiT blocks:

@@ -135,7 +135,6 @@ class UNetT(nn.Module):
         self.text_embed = TextEmbedding(
             text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
         )
-        self._cache_local = threading.local()  # thread-local storage for cache
         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
 
         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -186,21 +185,32 @@ class UNetT(nn.Module):
         self.norm_out = RMSNorm(dim)
         self.proj_out = nn.Linear(dim, mel_dim)
 
+    # `_cache_local` is lazily initialized on first inference-time cache write so that
+    # training models (which never touch the cache) stay deepcopy-friendly for EMA.
+    def _get_cache_local(self):
+        cache = self.__dict__.get("_cache_local")
+        if cache is None:
+            cache = threading.local()
+            self.__dict__["_cache_local"] = cache
+        return cache
+
     @property
     def text_cond(self):
-        return getattr(self._cache_local, "text_cond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_cond", None) if cache is not None else None
 
     @text_cond.setter
     def text_cond(self, value):
-        self._cache_local.text_cond = value
+        self._get_cache_local().text_cond = value
 
     @property
     def text_uncond(self):
-        return getattr(self._cache_local, "text_uncond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_uncond", None) if cache is not None else None
 
     @text_uncond.setter
     def text_uncond(self, value):
-        self._cache_local.text_uncond = value
+        self._get_cache_local().text_uncond = value
 
     def get_input_embed(
         self,
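
What the change buys, and its deliberate limit: a fresh (training-time) model deepcopies cleanly for EMA, while a model that has already written its inference cache still cannot be deepcopied. That is acceptable because, as the new comment notes, training code never touches the cache. A sketch, once more under the hypothetical CachedBackbone assumption:

import copy

model = CachedBackbone()
copy.deepcopy(model)        # OK: training never creates _cache_local

model.text_cond = "cached"  # first inference-time write creates the local
try:
    copy.deepcopy(model)
except TypeError as err:
    print(err)              # cannot pickle '_thread._local' object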