mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-12-22 15:16:47 -08:00
Merge branch 'master' into fix-vram
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import gc
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
from modules import errors
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
has_mps = getattr(torch, 'has_mps', False)
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
@@ -33,8 +35,7 @@ def enable_tf32():
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
device = get_optimal_device()
|
||||
device_codeformer = cpu if has_mps else device
|
||||
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
dtype = torch.float16
|
||||
|
||||
def randn(seed, shape):
|
||||
@@ -58,3 +59,11 @@ def randn_without_seed(shape):
|
||||
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def autocast():
|
||||
from modules import shared
|
||||
|
||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.autocast("cuda")
|
||||
|
||||
Reference in New Issue
Block a user