diff --git a/pyproject.toml b/pyproject.toml
index 93775b4..80423d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "1.1.19"
+version = "1.1.20"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index 9482972..e463c95 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -202,7 +202,6 @@ class DiT(nn.Module):
             average_upsampling=text_embedding_average_upsampling,
             conv_layers=conv_layers,
         )
-        self._cache_local = threading.local()  # thread-local storage for cache
         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
 
         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -235,21 +234,32 @@ class DiT(nn.Module):
 
         self.initialize_weights()
 
+    # `_cache_local` is lazily initialized on first inference-time cache write so that
+    # training models (which never touch the cache) stay deepcopy-friendly for EMA.
+    def _get_cache_local(self):
+        cache = self.__dict__.get("_cache_local")
+        if cache is None:
+            cache = threading.local()
+            self.__dict__["_cache_local"] = cache
+        return cache
+
     @property
     def text_cond(self):
-        return getattr(self._cache_local, "text_cond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_cond", None) if cache is not None else None
 
     @text_cond.setter
     def text_cond(self, value):
-        self._cache_local.text_cond = value
+        self._get_cache_local().text_cond = value
 
     @property
     def text_uncond(self):
-        return getattr(self._cache_local, "text_uncond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_uncond", None) if cache is not None else None
 
     @text_uncond.setter
     def text_uncond(self, value):
-        self._cache_local.text_uncond = value
+        self._get_cache_local().text_uncond = value
 
     def initialize_weights(self):
         # Zero-out AdaLN layers in DiT blocks:
diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py
index 451dd19..262ad0d 100644
--- a/src/f5_tts/model/backbones/mmdit.py
+++ b/src/f5_tts/model/backbones/mmdit.py
@@ -106,7 +106,6 @@ class MMDiT(nn.Module):
 
         self.time_embed = TimestepEmbedding(dim)
         self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
-        self._cache_local = threading.local()  # thread-local storage for cache
         self.audio_embed = AudioEmbedding(mel_dim, dim)
 
         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -137,21 +136,32 @@ class MMDiT(nn.Module):
 
         self.initialize_weights()
 
+    # `_cache_local` is lazily initialized on first inference-time cache write so that
+    # training models (which never touch the cache) stay deepcopy-friendly for EMA.
+    def _get_cache_local(self):
+        cache = self.__dict__.get("_cache_local")
+        if cache is None:
+            cache = threading.local()
+            self.__dict__["_cache_local"] = cache
+        return cache
+
     @property
     def text_cond(self):
-        return getattr(self._cache_local, "text_cond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_cond", None) if cache is not None else None
 
     @text_cond.setter
     def text_cond(self, value):
-        self._cache_local.text_cond = value
+        self._get_cache_local().text_cond = value
 
     @property
     def text_uncond(self):
-        return getattr(self._cache_local, "text_uncond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_uncond", None) if cache is not None else None
 
     @text_uncond.setter
     def text_uncond(self, value):
-        self._cache_local.text_uncond = value
+        self._get_cache_local().text_uncond = value
 
     def initialize_weights(self):
         # Zero-out AdaLN layers in MMDiT blocks:
diff --git a/src/f5_tts/model/backbones/unett.py b/src/f5_tts/model/backbones/unett.py
index 1aff1ff..b3b5513 100644
--- a/src/f5_tts/model/backbones/unett.py
+++ b/src/f5_tts/model/backbones/unett.py
@@ -135,7 +135,6 @@ class UNetT(nn.Module):
         self.text_embed = TextEmbedding(
             text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
         )
-        self._cache_local = threading.local()  # thread-local storage for cache
         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
 
         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -186,21 +185,32 @@ class UNetT(nn.Module):
         self.norm_out = RMSNorm(dim)
         self.proj_out = nn.Linear(dim, mel_dim)
 
+    # `_cache_local` is lazily initialized on first inference-time cache write so that
+    # training models (which never touch the cache) stay deepcopy-friendly for EMA.
+    def _get_cache_local(self):
+        cache = self.__dict__.get("_cache_local")
+        if cache is None:
+            cache = threading.local()
+            self.__dict__["_cache_local"] = cache
+        return cache
+
     @property
     def text_cond(self):
-        return getattr(self._cache_local, "text_cond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_cond", None) if cache is not None else None
 
     @text_cond.setter
     def text_cond(self, value):
-        self._cache_local.text_cond = value
+        self._get_cache_local().text_cond = value
 
     @property
     def text_uncond(self):
-        return getattr(self._cache_local, "text_uncond", None)
+        cache = self.__dict__.get("_cache_local")
+        return getattr(cache, "text_uncond", None) if cache is not None else None
 
     @text_uncond.setter
     def text_uncond(self, value):
-        self._cache_local.text_uncond = value
+        self._get_cache_local().text_uncond = value
 
     def get_input_embed(
         self,
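
Note (not part of the patch): a minimal sketch of why the lazy initialization matters. EMA utilities typically snapshot the model with copy.deepcopy, and a threading.local instance cannot be deep-copied, so an eagerly created `_cache_local` attribute makes that copy fail during training. With the lazy accessor, a model that never touches the cache never acquires the attribute. The `EagerCache` and `LazyCache` toy modules below are hypothetical names for illustration, not F5-TTS classes.

import copy
import threading

import torch.nn as nn


class EagerCache(nn.Module):
    # pre-patch behaviour: the thread-local exists from construction on
    def __init__(self):
        super().__init__()
        self._cache_local = threading.local()


class LazyCache(nn.Module):
    # post-patch behaviour: the thread-local is created only on first cache write
    def _get_cache_local(self):
        cache = self.__dict__.get("_cache_local")
        if cache is None:
            cache = threading.local()
            self.__dict__["_cache_local"] = cache
        return cache


copy.deepcopy(LazyCache())  # fine: no thread-local is present at copy time

try:
    copy.deepcopy(EagerCache())  # roughly what an EMA wrapper does during training
except TypeError as err:
    print(err)  # thread-local objects cannot be pickled or deep-copied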