v1.1.20: refactor cache handling in DiT, MMDiT, and UNetT classes (lazy init) to fix training bug (EMA deepcopy failure)

This commit is contained in:
SWivid
2026-04-20 15:28:10 +08:00
parent 650c177b14
commit 6f91022519
4 changed files with 46 additions and 16 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "f5-tts" name = "f5-tts"
version = "1.1.19" version = "1.1.20"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching" description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md" readme = "README.md"
license = {text = "MIT License"} license = {text = "MIT License"}

View File

@@ -202,7 +202,6 @@ class DiT(nn.Module):
average_upsampling=text_embedding_average_upsampling, average_upsampling=text_embedding_average_upsampling,
conv_layers=conv_layers, conv_layers=conv_layers,
) )
self._cache_local = threading.local() # thread-local storage for cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head) self.rotary_embed = RotaryEmbedding(dim_head)
@@ -235,21 +234,32 @@ class DiT(nn.Module):
self.initialize_weights() self.initialize_weights()
# `_cache_local` is lazily initialized on first inference-time cache write so that
# training models (which never touch the cache) stay deepcopy-friendly for EMA.
def _get_cache_local(self):
cache = self.__dict__.get("_cache_local")
if cache is None:
cache = threading.local()
self.__dict__["_cache_local"] = cache
return cache
@property @property
def text_cond(self): def text_cond(self):
return getattr(self._cache_local, "text_cond", None) cache = self.__dict__.get("_cache_local")
return getattr(cache, "text_cond", None) if cache is not None else None
@text_cond.setter @text_cond.setter
def text_cond(self, value): def text_cond(self, value):
self._cache_local.text_cond = value self._get_cache_local().text_cond = value
@property @property
def text_uncond(self): def text_uncond(self):
return getattr(self._cache_local, "text_uncond", None) cache = self.__dict__.get("_cache_local")
return getattr(cache, "text_uncond", None) if cache is not None else None
@text_uncond.setter @text_uncond.setter
def text_uncond(self, value): def text_uncond(self, value):
self._cache_local.text_uncond = value self._get_cache_local().text_uncond = value
def initialize_weights(self): def initialize_weights(self):
# Zero-out AdaLN layers in DiT blocks: # Zero-out AdaLN layers in DiT blocks:

View File

@@ -106,7 +106,6 @@ class MMDiT(nn.Module):
self.time_embed = TimestepEmbedding(dim) self.time_embed = TimestepEmbedding(dim)
self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding) self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
self._cache_local = threading.local() # thread-local storage for cache
self.audio_embed = AudioEmbedding(mel_dim, dim) self.audio_embed = AudioEmbedding(mel_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head) self.rotary_embed = RotaryEmbedding(dim_head)
@@ -137,21 +136,32 @@ class MMDiT(nn.Module):
self.initialize_weights() self.initialize_weights()
# `_cache_local` is lazily initialized on first inference-time cache write so that
# training models (which never touch the cache) stay deepcopy-friendly for EMA.
def _get_cache_local(self):
cache = self.__dict__.get("_cache_local")
if cache is None:
cache = threading.local()
self.__dict__["_cache_local"] = cache
return cache
@property @property
def text_cond(self): def text_cond(self):
return getattr(self._cache_local, "text_cond", None) cache = self.__dict__.get("_cache_local")
return getattr(cache, "text_cond", None) if cache is not None else None
@text_cond.setter @text_cond.setter
def text_cond(self, value): def text_cond(self, value):
self._cache_local.text_cond = value self._get_cache_local().text_cond = value
@property @property
def text_uncond(self): def text_uncond(self):
return getattr(self._cache_local, "text_uncond", None) cache = self.__dict__.get("_cache_local")
return getattr(cache, "text_uncond", None) if cache is not None else None
@text_uncond.setter @text_uncond.setter
def text_uncond(self, value): def text_uncond(self, value):
self._cache_local.text_uncond = value self._get_cache_local().text_uncond = value
def initialize_weights(self): def initialize_weights(self):
# Zero-out AdaLN layers in MMDiT blocks: # Zero-out AdaLN layers in MMDiT blocks:

View File

@@ -135,7 +135,6 @@ class UNetT(nn.Module):
self.text_embed = TextEmbedding( self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
) )
self._cache_local = threading.local() # thread-local storage for cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head) self.rotary_embed = RotaryEmbedding(dim_head)
@@ -186,21 +185,32 @@ class UNetT(nn.Module):
self.norm_out = RMSNorm(dim) self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim) self.proj_out = nn.Linear(dim, mel_dim)
# `_cache_local` is lazily initialized on first inference-time cache write so that
# training models (which never touch the cache) stay deepcopy-friendly for EMA.
def _get_cache_local(self):
cache = self.__dict__.get("_cache_local")
if cache is None:
cache = threading.local()
self.__dict__["_cache_local"] = cache
return cache
@property @property
def text_cond(self): def text_cond(self):
return getattr(self._cache_local, "text_cond", None) cache = self.__dict__.get("_cache_local")
return getattr(cache, "text_cond", None) if cache is not None else None
@text_cond.setter @text_cond.setter
def text_cond(self, value): def text_cond(self, value):
self._cache_local.text_cond = value self._get_cache_local().text_cond = value
@property @property
def text_uncond(self): def text_uncond(self):
return getattr(self._cache_local, "text_uncond", None) cache = self.__dict__.get("_cache_local")
return getattr(cache, "text_uncond", None) if cache is not None else None
@text_uncond.setter @text_uncond.setter
def text_uncond(self, value): def text_uncond(self, value):
self._cache_local.text_uncond = value self._get_cache_local().text_uncond = value
def get_input_embed( def get_input_embed(
self, self,