mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-12-22 23:26:48 -08:00
Merge branch 'master' into fix-vram
This commit is contained in:
@@ -8,7 +8,7 @@ import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import shared, modelloader
|
||||
from modules import devices, modelloader
|
||||
from modules.bsrgan_model_arch import RRDBNet
|
||||
from modules.paths import models_path
|
||||
|
||||
@@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
model.to(devices.device_bsrgan)
|
||||
torch.cuda.empty_cache()
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(shared.device)
|
||||
img = img.unsqueeze(0).to(devices.device_bsrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
|
||||
@@ -69,10 +69,14 @@ def setup_model(dirname):
|
||||
|
||||
self.net = net
|
||||
self.face_helper = face_helper
|
||||
self.net.to(devices.device_codeformer)
|
||||
|
||||
return net, face_helper
|
||||
|
||||
def send_model_to(self, device):
|
||||
self.net.to(device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def restore(self, np_image, w=None):
|
||||
np_image = np_image[:, :, ::-1]
|
||||
|
||||
@@ -82,6 +86,8 @@ def setup_model(dirname):
|
||||
if self.net is None or self.face_helper is None:
|
||||
return np_image
|
||||
|
||||
self.send_model_to(devices.device_codeformer)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
self.face_helper.read_image(np_image)
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
@@ -97,7 +103,7 @@ def setup_model(dirname):
|
||||
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||
del output
|
||||
devices.torch_gc()
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as error:
|
||||
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
|
||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||
@@ -113,10 +119,10 @@ def setup_model(dirname):
|
||||
if original_resolution != restored_img.shape[0:2]:
|
||||
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.net = None
|
||||
self.face_helper = None
|
||||
devices.torch_gc()
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
return restored_img
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import gc
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
from modules import errors
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
has_mps = getattr(torch, 'has_mps', False)
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
@@ -33,8 +35,7 @@ def enable_tf32():
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
device = get_optimal_device()
|
||||
device_codeformer = cpu if has_mps else device
|
||||
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
dtype = torch.float16
|
||||
|
||||
def randn(seed, shape):
|
||||
@@ -58,3 +59,11 @@ def randn_without_seed(shape):
|
||||
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def autocast():
|
||||
from modules import shared
|
||||
|
||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.autocast("cuda")
|
||||
|
||||
@@ -6,8 +6,7 @@ from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.esrgam_model_arch as arch
|
||||
from modules import shared, modelloader, images
|
||||
from modules.devices import has_mps
|
||||
from modules import shared, modelloader, images, devices
|
||||
from modules.paths import models_path
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import opts
|
||||
@@ -97,7 +96,7 @@ class UpscalerESRGAN(Upscaler):
|
||||
model = self.load_model(selected_model)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
model.to(devices.device_esrgan)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
@@ -112,7 +111,7 @@ class UpscalerESRGAN(Upscaler):
|
||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||
return None
|
||||
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
||||
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
||||
@@ -127,7 +126,7 @@ def upscale_without_tiling(model, img):
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(shared.device)
|
||||
img = img.unsqueeze(0).to(devices.device_esrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
|
||||
@@ -21,7 +21,7 @@ def gfpgann():
|
||||
global loaded_gfpgan_model
|
||||
global model_path
|
||||
if loaded_gfpgan_model is not None:
|
||||
loaded_gfpgan_model.gfpgan.to(shared.device)
|
||||
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||
return loaded_gfpgan_model
|
||||
|
||||
if gfpgan_constructor is None:
|
||||
@@ -37,25 +37,32 @@ def gfpgann():
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
model.gfpgan.to(shared.device)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def send_model_to(model, device):
|
||||
model.gfpgan.to(device)
|
||||
model.face_helper.face_det.to(device)
|
||||
model.face_helper.face_parse.to(device)
|
||||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
global loaded_gfpgan_model
|
||||
model = gfpgann()
|
||||
if model is None:
|
||||
return np_image
|
||||
|
||||
send_model_to(model, devices.device_gfpgan)
|
||||
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
model.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
del model
|
||||
loaded_gfpgan_model = None
|
||||
devices.torch_gc()
|
||||
send_model_to(model, devices.cpu)
|
||||
|
||||
return np_image
|
||||
|
||||
|
||||
@@ -287,6 +287,25 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||
if seed is not None:
|
||||
x = x.replace("[seed]", str(seed))
|
||||
|
||||
if p is not None:
|
||||
x = x.replace("[steps]", str(p.steps))
|
||||
x = x.replace("[cfg]", str(p.cfg_scale))
|
||||
x = x.replace("[width]", str(p.width))
|
||||
x = x.replace("[height]", str(p.height))
|
||||
|
||||
#currently disabled if using the save button, will work otherwise
|
||||
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
|
||||
if hasattr(p, "styles"):
|
||||
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
|
||||
|
||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||
|
||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
|
||||
|
||||
# Apply [prompt] at last. Because it may contain any replacement word.^M
|
||||
if prompt is not None:
|
||||
x = x.replace("[prompt]", sanitize_filename_part(prompt))
|
||||
if "[prompt_no_styles]" in x:
|
||||
@@ -295,7 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||
if len(style) > 0:
|
||||
style_parts = [y for y in style.split("{prompt}")]
|
||||
for part in style_parts:
|
||||
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
||||
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
||||
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
||||
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
|
||||
|
||||
@@ -306,24 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||
words = ["empty"]
|
||||
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
||||
|
||||
if p is not None:
|
||||
x = x.replace("[steps]", str(p.steps))
|
||||
x = x.replace("[cfg]", str(p.cfg_scale))
|
||||
x = x.replace("[width]", str(p.width))
|
||||
x = x.replace("[height]", str(p.height))
|
||||
|
||||
#currently disabled if using the save button, will work otherwise
|
||||
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
|
||||
if hasattr(p, "styles"):
|
||||
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
|
||||
|
||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||
|
||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
|
||||
|
||||
if cmd_opts.hide_ui_dir_config:
|
||||
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
|
||||
|
||||
@@ -379,7 +380,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||
|
||||
if save_to_dirs:
|
||||
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
|
||||
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
|
||||
path = os.path.join(path, dirname)
|
||||
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
@@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args):
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
save_normally = output_dir == ''
|
||||
|
||||
p.do_not_save_grid = True
|
||||
p.do_not_save_samples = True
|
||||
p.do_not_save_samples = not save_normally
|
||||
|
||||
state.job_count = len(images) * p.n_iter
|
||||
|
||||
@@ -48,7 +50,8 @@ def process_batch(p, input_dir, output_dir, args):
|
||||
left, right = os.path.splitext(filename)
|
||||
filename = f"{left}-{n}{right}"
|
||||
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
if not save_normally:
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
@@ -126,4 +129,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||
if opts.samples_log_stdout:
|
||||
print(generation_info_js)
|
||||
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@@ -85,7 +84,7 @@ class StableDiffusionProcessing:
|
||||
self.s_tmin = opts.s_tmin
|
||||
self.s_tmax = float('inf') # not representable as a standard ui option
|
||||
self.s_noise = opts.s_noise
|
||||
|
||||
|
||||
if not seed_enable_extras:
|
||||
self.subseed = -1
|
||||
self.subseed_strength = 0
|
||||
@@ -249,9 +248,16 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
return x
|
||||
|
||||
|
||||
def get_fixed_seed(seed):
|
||||
if seed is None or seed == '' or seed == -1:
|
||||
return int(random.randrange(4294967294))
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
def fix_seed(p):
|
||||
p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed
|
||||
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed
|
||||
p.seed = get_fixed_seed(p.seed)
|
||||
p.subseed = get_fixed_seed(p.subseed)
|
||||
|
||||
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
||||
@@ -290,10 +296,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
assert(len(p.prompt) > 0)
|
||||
else:
|
||||
assert p.prompt is not None
|
||||
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
fix_seed(p)
|
||||
seed = get_fixed_seed(p.seed)
|
||||
subseed = get_fixed_seed(p.subseed)
|
||||
|
||||
if p.outpath_samples is not None:
|
||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||
@@ -312,15 +319,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
else:
|
||||
all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
||||
|
||||
if type(p.seed) == list:
|
||||
all_seeds = p.seed
|
||||
if type(seed) == list:
|
||||
all_seeds = seed
|
||||
else:
|
||||
all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
|
||||
all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
|
||||
|
||||
if type(p.subseed) == list:
|
||||
all_subseeds = p.subseed
|
||||
if type(subseed) == list:
|
||||
all_subseeds = subseed
|
||||
else:
|
||||
all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))]
|
||||
all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
|
||||
|
||||
def infotext(iteration=0, position_in_batch=0):
|
||||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
||||
@@ -330,10 +337,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
|
||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
||||
p.init(all_prompts, all_seeds, all_subseeds)
|
||||
|
||||
with torch.no_grad():
|
||||
with devices.autocast():
|
||||
p.init(all_prompts, all_seeds, all_subseeds)
|
||||
|
||||
if state.job_count == -1:
|
||||
state.job_count = p.n_iter
|
||||
@@ -352,8 +359,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
||||
#c = p.sd_model.get_learned_conditioning(prompts)
|
||||
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
|
||||
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
|
||||
with devices.autocast():
|
||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
||||
c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||
|
||||
if len(model_hijack.comments) > 0:
|
||||
for comment in model_hijack.comments:
|
||||
@@ -362,13 +370,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||
with devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||
|
||||
if state.interrupted:
|
||||
|
||||
# if we are interruped, sample returns just noise
|
||||
# use the image collected previously in sampler loop
|
||||
samples_ddim = shared.state.current_latent
|
||||
|
||||
samples_ddim = samples_ddim.to(devices.dtype)
|
||||
|
||||
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
@@ -394,6 +406,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
|
||||
|
||||
x_sample = modules.face_restoration.restore_faces(x_sample)
|
||||
devices.torch_gc()
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
@@ -530,7 +543,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
# GC now before running the next img2img to prevent running out of memory
|
||||
x = None
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
|
||||
|
||||
return samples
|
||||
|
||||
@@ -1,19 +1,7 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
re_prompt = re.compile(r'''
|
||||
(.*?)
|
||||
\[
|
||||
([^]:]+):
|
||||
(?:([^]:]*):)?
|
||||
([0-9]*\.?[0-9]+)
|
||||
]
|
||||
|
|
||||
(.+)
|
||||
''', re.X)
|
||||
import lark
|
||||
|
||||
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
||||
# will be represented with prompt_schedule like this (assuming steps=100):
|
||||
@@ -23,71 +11,96 @@ re_prompt = re.compile(r'''
|
||||
# [75, 'fantasy landscape with a lake and an oak in background masterful']
|
||||
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
|
||||
|
||||
schedule_parser = lark.Lark(r"""
|
||||
!start: (prompt | /[][():]/+)*
|
||||
prompt: (emphasized | scheduled | plain | WHITESPACE)*
|
||||
!emphasized: "(" prompt ")"
|
||||
| "(" prompt ":" prompt ")"
|
||||
| "[" prompt "]"
|
||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
||||
WHITESPACE: /\s+/
|
||||
plain: /([^\\\[\]():]|\\.)+/
|
||||
%import common.SIGNED_NUMBER -> NUMBER
|
||||
""")
|
||||
|
||||
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
res = []
|
||||
cache = {}
|
||||
"""
|
||||
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
||||
>>> g("test")
|
||||
[[10, 'test']]
|
||||
>>> g("a [b:3]")
|
||||
[[3, 'a '], [10, 'a b']]
|
||||
>>> g("a [b: 3]")
|
||||
[[3, 'a '], [10, 'a b']]
|
||||
>>> g("a [[[b]]:2]")
|
||||
[[2, 'a '], [10, 'a [[b]]']]
|
||||
>>> g("[(a:2):3]")
|
||||
[[3, ''], [10, '(a:2)']]
|
||||
>>> g("a [b : c : 1] d")
|
||||
[[1, 'a b d'], [10, 'a c d']]
|
||||
>>> g("a[b:[c:d:2]:1]e")
|
||||
[[1, 'abe'], [2, 'ace'], [10, 'ade']]
|
||||
>>> g("a [unbalanced")
|
||||
[[10, 'a [unbalanced']]
|
||||
>>> g("a [b:.5] c")
|
||||
[[5, 'a c'], [10, 'a b c']]
|
||||
>>> g("a [{b|d{:.5] c") # not handling this right now
|
||||
[[5, 'a c'], [10, 'a {b|d{ c']]
|
||||
>>> g("((a][:b:c [d:3]")
|
||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||
"""
|
||||
|
||||
for prompt in prompts:
|
||||
prompt_schedule: list[list[str | int]] = [[steps, ""]]
|
||||
def collect_steps(steps, tree):
|
||||
l = [steps]
|
||||
class CollectSteps(lark.Visitor):
|
||||
def scheduled(self, tree):
|
||||
tree.children[-1] = float(tree.children[-1])
|
||||
if tree.children[-1] < 1:
|
||||
tree.children[-1] *= steps
|
||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||
l.append(tree.children[-1])
|
||||
CollectSteps().visit(tree)
|
||||
return sorted(set(l))
|
||||
|
||||
cached = cache.get(prompt, None)
|
||||
if cached is not None:
|
||||
res.append(cached)
|
||||
continue
|
||||
def at_step(step, tree):
|
||||
class AtStep(lark.Transformer):
|
||||
def scheduled(self, args):
|
||||
before, after, _, when = args
|
||||
yield before or () if step <= when else after
|
||||
def start(self, args):
|
||||
def flatten(x):
|
||||
if type(x) == str:
|
||||
yield x
|
||||
else:
|
||||
for gen in x:
|
||||
yield from flatten(gen)
|
||||
return ''.join(flatten(args))
|
||||
def plain(self, args):
|
||||
yield args[0].value
|
||||
def __default__(self, data, children, meta):
|
||||
for child in children:
|
||||
yield from child
|
||||
return AtStep().transform(tree)
|
||||
|
||||
for m in re_prompt.finditer(prompt):
|
||||
plaintext = m.group(1) if m.group(5) is None else m.group(5)
|
||||
concept_from = m.group(2)
|
||||
concept_to = m.group(3)
|
||||
if concept_to is None:
|
||||
concept_to = concept_from
|
||||
concept_from = ""
|
||||
swap_position = float(m.group(4)) if m.group(4) is not None else None
|
||||
def get_schedule(prompt):
|
||||
try:
|
||||
tree = schedule_parser.parse(prompt)
|
||||
except lark.exceptions.LarkError as e:
|
||||
if 0:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return [[steps, prompt]]
|
||||
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
|
||||
|
||||
if swap_position is not None:
|
||||
if swap_position < 1:
|
||||
swap_position = swap_position * steps
|
||||
swap_position = int(min(swap_position, steps))
|
||||
|
||||
swap_index = None
|
||||
found_exact_index = False
|
||||
for i in range(len(prompt_schedule)):
|
||||
end_step = prompt_schedule[i][0]
|
||||
prompt_schedule[i][1] += plaintext
|
||||
|
||||
if swap_position is not None and swap_index is None:
|
||||
if swap_position == end_step:
|
||||
swap_index = i
|
||||
found_exact_index = True
|
||||
|
||||
if swap_position < end_step:
|
||||
swap_index = i
|
||||
|
||||
if swap_index is not None:
|
||||
if not found_exact_index:
|
||||
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
|
||||
|
||||
for i in range(len(prompt_schedule)):
|
||||
end_step = prompt_schedule[i][0]
|
||||
must_replace = swap_position < end_step
|
||||
|
||||
prompt_schedule[i][1] += concept_to if must_replace else concept_from
|
||||
|
||||
res.append(prompt_schedule)
|
||||
cache[prompt] = prompt_schedule
|
||||
#for t in prompt_schedule:
|
||||
# print(t)
|
||||
|
||||
return res
|
||||
promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
|
||||
return [promptdict[prompt] for prompt in prompts]
|
||||
|
||||
|
||||
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
||||
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
|
||||
|
||||
|
||||
def get_learned_conditioning(prompts, steps):
|
||||
|
||||
def get_learned_conditioning(model, prompts, steps):
|
||||
res = []
|
||||
|
||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
||||
@@ -101,7 +114,7 @@ def get_learned_conditioning(prompts, steps):
|
||||
continue
|
||||
|
||||
texts = [x[1] for x in prompt_schedule]
|
||||
conds = shared.sd_model.get_learned_conditioning(texts)
|
||||
conds = model.get_learned_conditioning(texts)
|
||||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
||||
@@ -114,12 +127,13 @@ def get_learned_conditioning(prompts, steps):
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
||||
res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
|
||||
param = c.schedules[0][0].cond
|
||||
res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
|
||||
for i, cond_schedule in enumerate(c.schedules):
|
||||
target_index = 0
|
||||
for curret_index, (end_at, cond) in enumerate(cond_schedule):
|
||||
for current, (end_at, cond) in enumerate(cond_schedule):
|
||||
if current_step <= end_at:
|
||||
target_index = curret_index
|
||||
target_index = current
|
||||
break
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
|
||||
@@ -157,23 +171,26 @@ def parse_prompt_attention(text):
|
||||
\\ - literal character '\'
|
||||
anything else - just text
|
||||
|
||||
Example:
|
||||
|
||||
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
|
||||
|
||||
produces:
|
||||
|
||||
[
|
||||
['a ', 1.0],
|
||||
['house', 1.5730000000000004],
|
||||
[' ', 1.1],
|
||||
['on', 1.0],
|
||||
[' a ', 1.1],
|
||||
['hill', 0.55],
|
||||
[', sun, ', 1.1],
|
||||
['sky', 1.4641000000000006],
|
||||
['.', 1.1]
|
||||
]
|
||||
>>> parse_prompt_attention('normal text')
|
||||
[['normal text', 1.0]]
|
||||
>>> parse_prompt_attention('an (important) word')
|
||||
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||
>>> parse_prompt_attention('(unbalanced')
|
||||
[['unbalanced', 1.1]]
|
||||
>>> parse_prompt_attention('\(literal\]')
|
||||
[['(literal]', 1.0]]
|
||||
>>> parse_prompt_attention('(unnecessary)(parens)')
|
||||
[['unnecessaryparens', 1.1]]
|
||||
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||
[['a ', 1.0],
|
||||
['house', 1.5730000000000004],
|
||||
[' ', 1.1],
|
||||
['on', 1.0],
|
||||
[' a ', 1.1],
|
||||
['hill', 0.55],
|
||||
[', sun, ', 1.1],
|
||||
['sky', 1.4641000000000006],
|
||||
['.', 1.1]]
|
||||
"""
|
||||
|
||||
res = []
|
||||
@@ -215,4 +232,19 @@ def parse_prompt_attention(text):
|
||||
if len(res) == 0:
|
||||
res = [["", 1.0]]
|
||||
|
||||
# merge runs of identical weights
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return res
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
|
||||
else:
|
||||
import torch # doctest faster
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import shared, modelloader
|
||||
from modules import devices, modelloader
|
||||
from modules.paths import models_path
|
||||
from modules.scunet_model_arch import SCUNet as net
|
||||
|
||||
@@ -51,12 +51,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
if model is None:
|
||||
return img
|
||||
|
||||
device = shared.device
|
||||
device = devices.device_scunet
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(shared.device)
|
||||
img = img.unsqueeze(0).to(device)
|
||||
|
||||
img = img.to(device)
|
||||
with torch.no_grad():
|
||||
@@ -69,7 +69,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
device = shared.device
|
||||
device = devices.device_scunet
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
||||
|
||||
@@ -127,7 +127,7 @@ class VanillaStableDiffusionSampler:
|
||||
return res
|
||||
|
||||
def initialize(self, p):
|
||||
self.eta = p.eta or opts.eta_ddim
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ddim
|
||||
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
|
||||
@@ -12,7 +12,7 @@ import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.sd_models
|
||||
import modules.styles
|
||||
from modules.devices import get_optimal_device
|
||||
import modules.devices as devices
|
||||
from modules.paths import script_path, sd_path
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
@@ -46,6 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
@@ -54,6 +55,7 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="color-sketch")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
@@ -63,7 +65,11 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print
|
||||
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
device = get_optimal_device()
|
||||
|
||||
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
|
||||
|
||||
device = devices.device
|
||||
|
||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||
@@ -183,7 +189,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
@@ -195,7 +201,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
@@ -204,7 +210,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||
@@ -224,6 +230,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
|
||||
@@ -9,6 +9,9 @@ from torchvision import transforms
|
||||
import random
|
||||
import tqdm
|
||||
from modules import devices
|
||||
import re
|
||||
|
||||
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
@@ -38,8 +41,8 @@ class PersonalizedBase(Dataset):
|
||||
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
|
||||
filename = os.path.basename(path)
|
||||
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
|
||||
filename_tokens = [token for token in filename_tokens if token.isalpha()]
|
||||
filename_tokens = os.path.splitext(filename)[0]
|
||||
filename_tokens = re_tag.findall(filename_tokens)
|
||||
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
@@ -26,7 +26,9 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
|
||||
if process_caption:
|
||||
caption = "-" + shared.interrogator.generate_caption(image)
|
||||
else:
|
||||
caption = ""
|
||||
caption = filename
|
||||
caption = os.path.splitext(caption)[0]
|
||||
caption = os.path.basename(caption)
|
||||
|
||||
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
|
||||
subindex[0] += 1
|
||||
|
||||
@@ -164,7 +164,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
||||
|
||||
if save_embedding_every > 0:
|
||||
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||
|
||||
@@ -48,5 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
||||
if opts.samples_log_stdout:
|
||||
print(generation_info_js)
|
||||
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
|
||||
reuse_symbol = '\u267b\ufe0f' # ♻️
|
||||
art_symbol = '\U0001f3a8' # 🎨
|
||||
paste_symbol = '\u2199\ufe0f' # ↙
|
||||
folder_symbol = '\uD83D\uDCC2'
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
|
||||
def plaintext_to_html(text):
|
||||
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
||||
@@ -196,6 +196,11 @@ def wrap_gradio_call(func, extra_outputs=None):
|
||||
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.2f}s"
|
||||
if (elapsed_m > 0):
|
||||
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||
@@ -210,7 +215,7 @@ def wrap_gradio_call(func, extra_outputs=None):
|
||||
vram_html = ''
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
@@ -386,14 +391,22 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
||||
outputs=[seed, dummy_component]
|
||||
)
|
||||
|
||||
|
||||
def update_token_counter(text, steps):
|
||||
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
|
||||
try:
|
||||
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
|
||||
except Exception:
|
||||
# a parsing error can happen here during typing, and we don't want to bother the user with
|
||||
# messages related to it in console
|
||||
prompt_schedules = [[[steps, text]]]
|
||||
|
||||
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
|
||||
prompts = [prompt_text for step,prompt_text in flat_prompts]
|
||||
prompts = [prompt_text for step, prompt_text in flat_prompts]
|
||||
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
|
||||
style_class = ' class="red"' if (token_count > max_length) else ""
|
||||
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
||||
|
||||
|
||||
def create_toprow(is_img2img):
|
||||
id_part = "img2img" if is_img2img else "txt2img"
|
||||
|
||||
@@ -636,7 +649,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
|
||||
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
||||
with gr.TabItem('img2img', id='img2img'):
|
||||
init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil")
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
|
||||
|
||||
with gr.TabItem('Inpaint', id='inpaint'):
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
|
||||
@@ -658,7 +671,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
|
||||
with gr.TabItem('Batch img2img', id='batch'):
|
||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.{hidden}</p>")
|
||||
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
|
||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
|
||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user