Merge pull request #1270 from ZhikangNiu/main

- Use fused=True for AdamW by default

- Warn on torch attention mask memory usage `if attn_backend == "torch" and attn_mask_enabled`

---------

Co-authored-by: SWivid <swivid@qq.com>
This commit is contained in:
Zhikang Niu-SII
2026-03-04 19:31:52 +08:00
committed by GitHub
parent ab75dc2837
commit b5ab1afa16
2 changed files with 15 additions and 1 deletions

View File

@@ -11,6 +11,7 @@ d - dimension
 from __future__ import annotations
 import math
+import warnings
 from typing import Optional
 import torch
@@ -452,6 +453,12 @@ class AttnProcessor:
     ):
         if attn_backend == "flash_attn":
             assert is_package_available("flash_attn"), "Please install flash-attn first."
+        if attn_backend == "torch" and attn_mask_enabled:
+            warnings.warn(
+                "attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. "
+                "Please switch attn_backend to 'flash_attn'.",
+                UserWarning,
+            )
         self.pe_attn_head = pe_attn_head
         self.attn_backend = attn_backend
@@ -557,6 +564,12 @@ class JointAttnProcessor:
     ):
         if attn_backend == "flash_attn":
             assert is_package_available("flash_attn"), "Please install flash-attn first."
+        if attn_backend == "torch" and attn_mask_enabled:
+            warnings.warn(
+                "attn_mask_enabled=True with attn_backend='torch' can consume large GPU memory. "
+                "Please switch attn_backend to 'flash_attn'.",
+                UserWarning,
+            )
         self.attn_backend = attn_backend
         self.attn_mask_enabled = attn_mask_enabled

View File

@@ -85,6 +85,7 @@ class Trainer:
             "grad_accumulation_steps": grad_accumulation_steps,
             "max_grad_norm": max_grad_norm,
             "noise_scheduler": noise_scheduler,
+            "bnb_optimizer": bnb_optimizer,
         }
         model_cfg_dict["gpus"] = self.accelerator.num_processes
         self.accelerator.init_trackers(
@@ -139,7 +140,7 @@ class Trainer:
             self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
         else:
-            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
+            self.optimizer = AdamW(model.parameters(), lr=learning_rate, fused=True)
         self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
     @property