Mirror of https://github.com/SWivid/F5-TTS.git, synced 2026-03-12 21:02:50 -07:00.
Merge pull request #1270 from ZhikangNiu/main

- Use fused=True for AdamW by default
- Warn on torch attention mask memory usage (`if attn_backend == "torch" and attn_mask_enabled`)

Co-authored-by: SWivid <swivid@qq.com>
This commit is contained in:
@@ -11,6 +11,7 @@ d - dimension
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -452,6 +453,12 @@ class AttnProcessor:
|
||||
):
|
||||
if attn_backend == "flash_attn":
|
||||
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
||||
if attn_backend == "torch" and attn_mask_enabled:
|
||||
warnings.warn(
|
||||
"attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. "
|
||||
"Please switch attn_backend to 'flash_attn'.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
self.pe_attn_head = pe_attn_head
|
||||
self.attn_backend = attn_backend
|
||||
@@ -557,6 +564,12 @@ class JointAttnProcessor:
|
||||
):
|
||||
if attn_backend == "flash_attn":
|
||||
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
||||
if attn_backend == "torch" and attn_mask_enabled:
|
||||
warnings.warn(
|
||||
"attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. "
|
||||
"Please switch attn_backend to 'flash_attn'.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
self.attn_backend = attn_backend
|
||||
self.attn_mask_enabled = attn_mask_enabled
|
||||
|
||||
@@ -85,6 +85,7 @@ class Trainer:
|
||||
"grad_accumulation_steps": grad_accumulation_steps,
|
||||
"max_grad_norm": max_grad_norm,
|
||||
"noise_scheduler": noise_scheduler,
|
||||
"bnb_optimizer": bnb_optimizer,
|
||||
}
|
||||
model_cfg_dict["gpus"] = self.accelerator.num_processes
|
||||
self.accelerator.init_trackers(
|
||||
@@ -139,7 +140,7 @@ class Trainer:
|
||||
|
||||
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
||||
else:
|
||||
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
||||
self.optimizer = AdamW(model.parameters(), lr=learning_rate, fused=True)
|
||||
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user