mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-03-12 21:02:50 -07:00
Merge pull request #1270 from ZhikangNiu/main
- Use fused=True for AdamW by default - Warn on torch attention mask memory usage `if attn_backend == "torch" and attn_mask_enabled` --------- Co-authored-by: SWivid <swivid@qq.com>
This commit is contained in:
@@ -11,6 +11,7 @@ d - dimension
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -452,6 +453,12 @@ class AttnProcessor:
|
|||||||
):
|
):
|
||||||
if attn_backend == "flash_attn":
|
if attn_backend == "flash_attn":
|
||||||
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
||||||
|
if attn_backend == "torch" and attn_mask_enabled:
|
||||||
|
warnings.warn(
|
||||||
|
"attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. "
|
||||||
|
"Please switch attn_backend to 'flash_attn'.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
self.pe_attn_head = pe_attn_head
|
self.pe_attn_head = pe_attn_head
|
||||||
self.attn_backend = attn_backend
|
self.attn_backend = attn_backend
|
||||||
@@ -557,6 +564,12 @@ class JointAttnProcessor:
|
|||||||
):
|
):
|
||||||
if attn_backend == "flash_attn":
|
if attn_backend == "flash_attn":
|
||||||
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
||||||
|
if attn_backend == "torch" and attn_mask_enabled:
|
||||||
|
warnings.warn(
|
||||||
|
"attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. "
|
||||||
|
"Please switch attn_backend to 'flash_attn'.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
self.attn_backend = attn_backend
|
self.attn_backend = attn_backend
|
||||||
self.attn_mask_enabled = attn_mask_enabled
|
self.attn_mask_enabled = attn_mask_enabled
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ class Trainer:
|
|||||||
"grad_accumulation_steps": grad_accumulation_steps,
|
"grad_accumulation_steps": grad_accumulation_steps,
|
||||||
"max_grad_norm": max_grad_norm,
|
"max_grad_norm": max_grad_norm,
|
||||||
"noise_scheduler": noise_scheduler,
|
"noise_scheduler": noise_scheduler,
|
||||||
|
"bnb_optimizer": bnb_optimizer,
|
||||||
}
|
}
|
||||||
model_cfg_dict["gpus"] = self.accelerator.num_processes
|
model_cfg_dict["gpus"] = self.accelerator.num_processes
|
||||||
self.accelerator.init_trackers(
|
self.accelerator.init_trackers(
|
||||||
@@ -139,7 +140,7 @@ class Trainer:
|
|||||||
|
|
||||||
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
||||||
else:
|
else:
|
||||||
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
self.optimizer = AdamW(model.parameters(), lr=learning_rate, fused=True)
|
||||||
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
|
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Reference in New Issue
Block a user