mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-12-22 15:16:47 -08:00
Following LyCORIS 1.9.0.dev6, add a new option and module for training norm layers (reportedly good for style)
531 lines
19 KiB
Python
531 lines
19 KiB
Python
import os
|
|
import re
|
|
|
|
import network
|
|
import network_lora
|
|
import network_hada
|
|
import network_ia3
|
|
import network_lokr
|
|
import network_full
|
|
import network_norm
|
|
|
|
import torch
|
|
from typing import Union
|
|
|
|
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
|
|
|
# Every supported network module type, instantiated once. load_network() probes them in this
# order and the first whose create_module() returns something other than None claims a layer.
module_types = [
    module_type()
    for module_type in (
        network_lora.ModuleTypeLora,
        network_hada.ModuleTypeHada,
        network_ia3.ModuleTypeIa3,
        network_lokr.ModuleTypeLokr,
        network_full.ModuleTypeFull,
        network_norm.ModuleTypeNorm,
    )
]
|
|
|
|
|
|
# A run of decimal digits; used to decide which regex groups become ints.
re_digits = re.compile(r"\d+")
# Splits "<prefix>_q_proj" / "_k_proj" / "_v_proj" into (prefix, projection name).
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
# Lazily filled cache: pattern text -> compiled pattern.
re_compiled = {}

# diffusers sub-layer name -> compvis sub-layer name, per block kind.
# Attention sub-layers keep their names, so that table is empty.
suffix_conversion = {
    "attentions": {},
    "resnets": dict(
        conv1="in_layers_2",
        conv2="out_layers_3",
        norm1="in_layers_0",
        norm2="out_layers_0",
        time_emb_proj="emb_layers_1",
        conv_shortcut="skip_connection",
    ),
}
|
|
|
|
|
|
def convert_diffusers_name_to_compvis(key, is_sd2):
    """
    Translate a diffusers/kohya-style network layer name into the compvis-style name used as a
    key of ``network_layer_mapping``.

    Args:
        key: layer name with the weight-part suffix already removed,
            e.g. "lora_unet_down_blocks_0_resnets_1_conv1".
        is_sd2: True when the target text encoder uses ``model_transformer_resblocks_*`` layer
            names; selects that naming for "lora_te_" keys.

    Returns:
        The converted name, or ``key`` unchanged when no known pattern applies.
    """

    def match(match_list, regex_text):
        # Match `key` against `regex_text` (compiled once and cached in re_compiled). On success,
        # replace match_list's contents with the groups, turning all-digit groups into ints so
        # they can be used in the index arithmetic below.
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = regex.match(key)
        if not r:
            return False

        match_list.clear()
        # fullmatch, not match: a group that merely starts with digits (e.g. "1abc") must stay a
        # string instead of crashing int().
        match_list.extend([int(x) if re_digits.fullmatch(x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    # linear_1 -> time_embed_0, linear_2 -> time_embed_2
    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    # down block b, layer i -> input_blocks[1 + 3*b + i]; its downsampler sits at 3 + 3*b (below)
    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    # middle block: the attention is slot 1, resnet i is slot 2*i
    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    # up block 0 has no attention layer, so its upsampler is sub-module 1 instead of 2
    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    # second text encoder (SDXL): embedder index 1 prefix
    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
|
|
|
|
|
|
def assign_network_names_to_compvis_modules(sd_model):
    """
    Index every module of ``sd_model`` by its underscore-joined path.

    Stores the result as ``sd_model.network_layer_mapping`` (name -> module) and tags each module
    with its name as ``module.network_layer_name``. Text-encoder modules come from the ``wrapped``
    model of the conditioning stage; for SDXL every embedder that has a ``wrapped`` model is
    indexed with its position as a prefix ("0_...", "1_..."). UNet modules come from
    ``sd_model.model``. When two paths produce the same name, the later one wins.

    Args:
        sd_model: the model to index. The mapping is built from this object (not from
            shared.sd_model) so it always describes the model it is stored on.
    """
    network_layer_mapping = {}

    def register(network_name, module):
        # record the module under its name and tag it so the patched forwards can find it
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    if getattr(sd_model, 'is_sdxl', False):
        for i, embedder in enumerate(sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                register(f'{i}_{name.replace(".", "_")}', module)
    else:
        for name, module in sd_model.cond_stage_model.wrapped.named_modules():
            register(name.replace(".", "_"), module)

    for name, module in sd_model.model.named_modules():
        register(name.replace(".", "_"), module)

    sd_model.network_layer_mapping = network_layer_mapping
|
|
|
|
|
|
def load_network(name, network_on_disk):
    """
    Read a network file and build a network.Network from its state dict.

    Each state-dict key is split at its first "." into a layer name and a weight-part name. The
    layer name is resolved to a module of shared.sd_model through network_layer_mapping, trying
    several naming conventions in turn. Weights for the same resolved layer are grouped into one
    network.NetworkWeights. Each group is then passed to the first entry of module_types whose
    create_module() returns a module. Keys that resolve to no layer are printed, not raised.

    Args:
        name: name given to the resulting network.Network.
        network_on_disk: entry describing the file; its ``filename`` is read.

    Returns:
        The populated network.Network; ``net.mtime`` holds the file's modification time.

    Raises:
        ValueError: if a state-dict key contains no "." (the two-name unpacking of split fails).
        AssertionError: if no module type accepts the weights grouped for some layer.
    """
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    # original state-dict key -> last name tried for it
    keys_failed_to_match = {}
    # a layer named exactly "model_transformer_resblocks" selects the SD2 text-encoder naming
    # in convert_diffusers_name_to_compvis
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    # resolved layer name -> network.NetworkWeights collecting all weight parts for that layer
    matched_networks = {}

    for key_network, weight in sd.items():
        # "layer_name.weight_part" -> ("layer_name", "weight_part")
        key_network_without_network_parts, network_part = key_network.split(".", 1)

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            # q/k/v projection keys: look up the layer registered under the prefix. `key` keeps
            # its _x_proj suffix, so each projection's weights stay grouped separately.
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            # SDXL first text encoder: its modules are registered with a "0_" prefix
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            # network_key records the first state-dict key seen for this layer
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        # the first module type that accepts this layer's set of weight parts wins
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    if keys_failed_to_match:
        print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")

    return net
|
|
|
|
|
|
def purge_networks_from_memory():
    """
    Shrink the in-memory network cache to at most shared.opts.lora_in_memory_limit entries,
    dropping the least recently inserted ones first, then call devices.torch_gc().
    """
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and networks_in_memory:
        # dicts preserve insertion order, so the first key is the least recently inserted entry
        oldest = next(iter(networks_in_memory))
        del networks_in_memory[oldest]

    devices.torch_gc()
|
|
|
|
|
|
|
|
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """
    Make the networks named in ``names`` the active set (``loaded_networks``), in that order.

    A network already active or cached in ``networks_in_memory`` is reused unless its file has
    been modified since it was loaded. Files that fail to load are reported via errors.display
    and skipped. Names with no file on disk (even after one rescan) and no already-active network
    are reported in the hijack comments. The optional lists give per-network te/unet multipliers
    and dyn_dim, indexed like ``names``; when a list is not given the value is 1.0. Finally the
    in-memory cache is trimmed.
    """
    def setting(values, index):
        # per-network value when a list was supplied, otherwise 1.0
        return values[index] if values else 1.0

    def resolve_on_disk():
        return [available_network_aliases.get(requested, None) for requested in names]

    # active networks that are requested again can be reused without reloading
    already_loaded = {active.name: active for active in loaded_networks if active.name in names}

    loaded_networks.clear()

    networks_on_disk = resolve_on_disk()
    if any(entry is None for entry in networks_on_disk):
        # something unknown was requested: rescan the directories once and retry
        list_available_networks()
        networks_on_disk = resolve_on_disk()

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            needs_reload = net is None or os.path.getmtime(network_on_disk.filename) > net.mtime
            if needs_reload:
                try:
                    net = load_network(name, network_on_disk)

                    # re-insert so the fresh copy is the newest entry for purging
                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name
            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            print(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = setting(te_multipliers, i)
        net.unet_multiplier = setting(unet_multipliers, i)
        net.dyn_dim = setting(dyn_dims, i)
        loaded_networks.append(net)

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))

    purge_networks_from_memory()
|
|
|
|
|
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Copy the saved original weights and bias back into layer ``self``.

    Reads ``network_weights_backup`` and ``network_bias_backup``. For MultiheadAttention the weight
    backup is an (in_proj_weight, out_proj.weight) pair; otherwise it is a single tensor. Does
    nothing when neither backup is set. Values are copied in place, so the layer keeps its
    existing parameter objects.
    """
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    # In-place writes into parameters that require grad raise unless autograd is disabled, and
    # callers (e.g. the weight-application path) invoke this outside their own no_grad() block.
    with torch.no_grad():
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.in_proj_weight.copy_(weights_backup[0])
                self.out_proj.weight.copy_(weights_backup[1])
            else:
                self.weight.copy_(weights_backup)

        if bias_backup is not None:
            self.bias.copy_(bias_backup)
|
|
|
|
|
|
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.

    If weights already have this particular set of networks (names, multipliers and dyn_dim)
    applied, does nothing. If not, restores the original weights from backup and adds every
    active network's weight delta (and bias delta, when the network provides one and the layer
    has a bias). Backups of the original weight/bias are taken on first use, on the CPU.
    """
    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None and getattr(self, 'bias', None) is not None:
        bias_backup = self.bias.to(devices.cpu, copy=True)
        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                with torch.no_grad():
                    updown, ex_bias = module.calc_updown(self.weight)

                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                        # inpainting model. zero pad updown to make channel[1]  4 to 9
                        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                    self.weight += updown
                    # calc_updown may return no bias delta; adding None to a tensor would raise
                    if ex_bias is not None and getattr(self, 'bias', None) is not None:
                        self.bias += ex_bias
                continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                with torch.no_grad():
                    # calc_updown returns (weight delta, bias delta). Bias deltas are discarded
                    # here: the bias backup above only covers a `bias` attribute, so changing the
                    # attention's biases could not be undone by restoring from backup.
                    updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                    updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                    updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                    updown_out, _ = module_out.calc_updown(self.out_proj.weight)

                    self.in_proj_weight += updown_qkv
                    self.out_proj.weight += updown_out
                continue

            if module is None:
                continue

            print(f'failed to calculate network weights for layer {network_layer_name}')

        self.network_current_names = wanted_names
|
|
|
|
|
|
def network_forward(module, input, original_forward):
    """
    Functional application of the active networks during the layer's forward pass.

    Instead of changing the weights, the layer first gets its original weights back (and its
    cached application state is dropped). It then runs ``original_forward``, and each active
    network that has an entry for this layer transforms the output in turn. Stacking many
    networks this way is much slower than baking them into the weights.
    """
    if not loaded_networks:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)

    # run with untouched weights and forget any previously baked-in application
    network_restore_weights_from_backup(module)
    network_reset_cached_weight(module)

    output = original_forward(module, input)

    layer_name = getattr(module, 'network_layer_name', None)
    for active in loaded_networks:
        net_module = active.modules.get(layer_name, None)
        if net_module is not None:
            output = net_module.forward(input, output)

    return output
|
|
|
|
|
|
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    """
    Forget layer ``self``'s cached network state, so the next application takes fresh backups
    and re-applies the active networks.
    """
    self.network_current_names = ()
    self.network_weights_backup = None
    # The bias backup must go too: after the layer's parameters are replaced (e.g. through
    # load_state_dict), a stale bias backup would be copied back over the new bias on restore.
    self.network_bias_backup = None
|
|
|
|
|
|
def network_Linear_forward(self, input):
    """Replacement forward for torch.nn.Linear that takes the active networks into account."""
    if not shared.opts.lora_functional:
        # bake the active networks into the weights, then run the saved pre-patch forward
        network_apply_weights(self)
        return torch.nn.Linear_forward_before_network(self, input)

    # functional mode: weights stay original; network_forward handles the networks
    return network_forward(self, input, torch.nn.Linear_forward_before_network)
|
|
|
|
|
|
def network_Linear_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on ``self``, then delegate to the saved pre-patch loader."""
    network_reset_cached_weight(self)
    original_loader = torch.nn.Linear_load_state_dict_before_network
    return original_loader(self, *args, **kwargs)
|
|
|
|
|
|
def network_Conv2d_forward(self, input):
    """Replacement forward for torch.nn.Conv2d that takes the active networks into account."""
    if not shared.opts.lora_functional:
        # bake the active networks into the weights, then run the saved pre-patch forward
        network_apply_weights(self)
        return torch.nn.Conv2d_forward_before_network(self, input)

    # functional mode: weights stay original; network_forward handles the networks
    return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
|
|
|
|
|
|
def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on ``self``, then delegate to the saved pre-patch loader."""
    network_reset_cached_weight(self)
    original_loader = torch.nn.Conv2d_load_state_dict_before_network
    return original_loader(self, *args, **kwargs)
|
|
|
|
|
|
def network_GroupNorm_forward(self, input):
    """Replacement forward for torch.nn.GroupNorm that takes the active networks into account."""
    if not shared.opts.lora_functional:
        # bake the active networks into the weights, then run the saved pre-patch forward
        network_apply_weights(self)
        return torch.nn.GroupNorm_forward_before_network(self, input)

    # functional mode: weights stay original; network_forward handles the networks
    return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
|
|
|
|
|
|
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on ``self``, then delegate to the saved pre-patch loader."""
    network_reset_cached_weight(self)
    original_loader = torch.nn.GroupNorm_load_state_dict_before_network
    return original_loader(self, *args, **kwargs)
|
|
|
|
|
|
def network_LayerNorm_forward(self, input):
    """Replacement forward for torch.nn.LayerNorm that takes the active networks into account."""
    if not shared.opts.lora_functional:
        # bake the active networks into the weights, then run the saved pre-patch forward
        network_apply_weights(self)
        return torch.nn.LayerNorm_forward_before_network(self, input)

    # functional mode: weights stay original; network_forward handles the networks
    return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
|
|
|
|
|
|
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on ``self``, then delegate to the saved pre-patch loader."""
    network_reset_cached_weight(self)
    original_loader = torch.nn.LayerNorm_load_state_dict_before_network
    return original_loader(self, *args, **kwargs)
|
|
|
|
|
|
def network_MultiheadAttention_forward(self, *args, **kwargs):
    """
    Replacement forward for torch.nn.MultiheadAttention.

    Unlike the other patched forwards, this one does not check lora_functional: the active
    networks are always baked into the weights before the saved pre-patch forward runs.
    """
    network_apply_weights(self)
    original_forward = torch.nn.MultiheadAttention_forward_before_network
    return original_forward(self, *args, **kwargs)
|
|
|
|
|
|
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on ``self``, then delegate to the saved pre-patch loader."""
    network_reset_cached_weight(self)
    original_loader = torch.nn.MultiheadAttention_load_state_dict_before_network
    return original_loader(self, *args, **kwargs)
|
|
|
|
|
|
def list_available_networks():
    """
    Rescan the network directories and rebuild the module-level registries.

    Clears available_networks, available_network_aliases, forbidden_network_aliases and
    available_network_hash_lookup. It then registers every .pt/.ckpt/.safetensors file under
    shared.cmd_opts.lora_dir (created if missing) and shared.cmd_opts.lyco_dir_backcompat.
    Each network is reachable by file name (without extension) and by its alias. An alias that
    collides with an already registered name or alias is recorded, lower-cased, in
    forbidden_network_aliases. Files whose entry cannot be created (OSError) are reported and
    skipped.
    """
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    # Keys of forbidden_network_aliases are stored lower-cased (see .lower() below), so the
    # built-in entries must be lower-case too; a mixed-case key could never match such a lookup.
    forbidden_network_aliases.update({"none": 1, "addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    extensions = [".pt", ".ckpt", ".safetensors"]
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=extensions))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=extensions))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry
|
|
|
|
|
|
# Parses AddNet "name (hash)" model references: group 1 is the name without the trailing
# parenthesised hexadecimal hash. The name must end in a non-space character so the optional
# whitespace before "(" is not captured into it (with a plain greedy "(.*)" it would be).
re_network_name = re.compile(r"(.*\S)\s*\([0-9a-fA-F]+\)")
|
|
|
|
|
|
def infotext_pasted(infotext, params):
    """
    Turn AddNet LoRA fields from pasted infotext into ``<lora:name:weight>`` prompt tags.

    Skipped entirely when txt2img already exposes an "AddNet Module 1" infotext field. For every
    "AddNet Model N" whose "AddNet Module N" is "LoRA", the model name (with any "(hash)" suffix
    removed) and "AddNet Weight A N" (default "1.0") become a tag. All tags are appended to
    params["Prompt"] on a new line.
    """
    if "AddNet Module 1" in (field[1] for field in scripts.scripts_txt2img.infotext_fields):
        return  # if the other extension is active, it will handle those fields, no need to do anything

    model_prefix = "AddNet Model "
    lora_tags = []

    for field_name in params:
        if not field_name.startswith(model_prefix):
            continue

        num = field_name[len(model_prefix):]
        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get(model_prefix + num)
        if name is None:
            continue

        hashed = re_network_name.match(name)
        if hashed:
            name = hashed.group(1)

        weight = params.get("AddNet Weight A " + num, "1.0")
        lora_tags.append(f"<lora:{name}:{weight}>")

    if lora_tags:
        params["Prompt"] += "\n" + "".join(lora_tags)
|
|
|
|
|
|
# Networks currently active, in the order load_networks() received their names.
loaded_networks: list = []

# Every network file found by list_available_networks(), keyed by file name without extension.
available_networks: dict = {}
# Same entries, reachable by file name and by alias.
available_network_aliases: dict = {}
# Loaded networks cached by requested name; insertion order decides what purging drops first.
networks_in_memory: dict = {}
# Cleared on every rescan; presumably maps hashes to networks and is filled elsewhere.
available_network_hash_lookup: dict = {}
# Aliases marked unusable: collisions found while scanning plus a few built-in entries.
forbidden_network_aliases: dict = {}

# populate the registries once at import time
list_available_networks()
|