# Snapshot metadata: revision 19a1abb, 5,967 bytes.
from lora_diffusion.cli_lora_add import *
from lora_diffusion.lora import *
from lora_diffusion.to_ckpt_v2 import *
def monkeypatch_or_replace_safeloras(models, safeloras):
    """Apply every LoRA found in *safeloras* to the matching attribute of *models*.

    Args:
        models: Object whose attributes are looked up by the module names
            stored in the LoRA data (``patch_pipe`` passes a pipeline here).
        safeloras: Loaded LoRA mapping with ``"metadata"`` and ``"weights"``
            entries, as consumed by ``parse_safeloras``.

    Modules named in the LoRA data but not present on *models* are reported
    with a printed message and skipped.
    """
    loras = parse_safeloras(safeloras)
    for name, (lora, ranks, target) in loras.items():
        model = getattr(models, name, None)
        # Compare against None explicitly: container modules such as
        # nn.Sequential / nn.ModuleList define __len__, so an empty one is
        # falsy and would otherwise be wrongly reported as missing.
        if model is None:
            print(f"No model provided for {name}, contained in Lora")
            continue
        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
def parse_safeloras(
    safeloras,
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
    """Group the LoRA tensors of a loaded LoRA mapping by module and decode them.

    Args:
        safeloras: Mapping with a ``"metadata"`` lookup and a ``"weights"``
            lookup (weight key -> tensor). Weight keys have the form
            ``"<module>:<index>:<direction>"``; the metadata must hold, per
            module, a JSON-encoded target list under ``"<module>"`` and an
            integer rank under ``"<module>:<index>:rank"``.

    Returns:
        ``{module_name: (weights, ranks, targets)}`` where ``weights`` is a
        list of Parameters with the non-``"down"`` tensor of pair ``i`` at
        slot ``2*i`` and the ``"down"`` tensor at slot ``2*i + 1``, ``ranks``
        holds one rank per pair, and ``targets`` is the decoded JSON metadata.
        Modules whose metadata equals ``EMBED_FLAG`` (textual-inversion
        embeds) are skipped.

    Raises:
        ValueError: if a module has no (or empty) metadata entry.
    """
    meta = safeloras['metadata']
    tensors = safeloras['weights']

    def _module_of(key):
        # The module name is everything before the first ":".
        return key.split(":")[0]

    parsed = {}
    # groupby only merges adjacent items, so the keys must be sorted by module first.
    ordered_keys = sorted(tensors.keys(), key=_module_of)
    for module_name, group in groupby(ordered_keys, _module_of):
        module_info = meta.get(module_name)
        if not module_info:
            raise ValueError(
                f"Tensor {module_name} has no metadata - is this a Lora safetensor?"
            )
        if module_info == EMBED_FLAG:
            # Textual-inversion embeds are handled by parse_safeloras_embeds.
            continue

        targets = json.loads(module_info)
        group_keys = list(group)
        # Slots are filled out of order by index, so pre-size both lists.
        module_ranks = [4] * (len(group_keys) // 2)
        module_weights = [None] * len(group_keys)
        for key in group_keys:
            _, pair_text, direction = key.split(":")
            pair = int(pair_text)
            module_ranks[pair] = int(meta[f"{module_name}:{pair}:rank"])
            slot = 2 * pair + (1 if direction == "down" else 0)
            module_weights[slot] = nn.parameter.Parameter(tensors[key])

        parsed[module_name] = (module_weights, module_ranks, targets)
    return parsed
def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """Extract the textual-inversion embeds from a loaded LoRA mapping.

    Args:
        safeloras: Mapping with ``"metadata"`` and ``"weights"`` entries.

    Returns:
        Dict mapping embed token (the weight key) to its tensor, in the
        iteration order of ``safeloras["weights"]``. A weight counts as an
        embed only when its own key is present in the metadata with a truthy
        value equal to ``EMBED_FLAG``; all other tensors are ignored.
    """
    tags = safeloras['metadata']
    tensors = safeloras['weights']

    def _is_embed(key):
        tag = tags[key] if key in tags else None
        return bool(tag) and tag == EMBED_FLAG

    return {key: tensors[key] for key in tensors.keys() if _is_embed(key)}
def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    """Patch *pipe* in place with the LoRAs (and optionally TI embeds) of a loaded LoRA.

    Args:
        pipe: Pipeline whose attributes are patched by module name; it must
            also expose ``text_encoder`` and ``tokenizer`` when ``patch_ti``.
        maybe_unet_path: Despite the name, an already-loaded LoRA mapping with
            ``"metadata"`` and ``"weights"`` entries, not a filesystem path.
        token: Forwarded to ``apply_learned_embed_in_clip``.
        r: Unused; ranks are read from the LoRA metadata. Kept for API
            compatibility.
        patch_unet: When False, LoRA tensors grouped under the ``"unet"``
            module name are not applied.
        patch_text: When False, LoRA tensors grouped under the
            ``"text_encoder"`` module name are not applied.
        patch_ti: When True, load the textual-inversion embeds into
            ``pipe.text_encoder`` / ``pipe.tokenizer``.
        idempotent_token: Forwarded as ``idempotent`` to
            ``apply_learned_embed_in_clip``.
        unet_target_replace_module, text_target_replace_module: Unused; the
            replacement targets come from the LoRA metadata. Kept for API
            compatibility.

    Returns:
        Dict mapping embed token -> tensor, returned whether or not
        ``patch_ti`` is set.
    """
    safeloras = maybe_unet_path

    # Honour patch_unet / patch_text by dropping the corresponding module
    # groups (keys are "<module>:<index>:<direction>") before patching.
    # Previously both flags were silently ignored.
    skipped = set()
    if not patch_unet:
        skipped.add("unet")
    if not patch_text:
        skipped.add("text_encoder")
    if skipped:
        weights = safeloras['weights']
        to_patch = {
            'metadata': safeloras['metadata'],
            'weights': {
                key: weights[key]
                for key in weights.keys()
                if key.split(":")[0] not in skipped
            },
        }
    else:
        to_patch = safeloras
    monkeypatch_or_replace_safeloras(pipe, to_patch)

    tok_dict = parse_safeloras_embeds(safeloras)
    if patch_ti:
        apply_learned_embed_in_clip(
            tok_dict,
            pipe.text_encoder,
            pipe.tokenizer,
            token=token,
            idempotent=idempotent_token,
        )
    return tok_dict
def lora_convert(model_path, as_half):
    """Build an original-SD-layout checkpoint state dict from a diffusers folder.

    Modified version of ``lora_diffusion.to_ckpt_v2.convert_to_ckpt`` that
    returns the merged state dict.

    Args:
        model_path: Directory containing ``unet/diffusion_pytorch_model.bin``,
            ``vae/diffusion_pytorch_model.bin`` and
            ``text_encoder/pytorch_model.bin``.
        as_half: When True, every tensor in the result is cast with ``.half()``.

    Returns:
        Dict mapping checkpoint keys to tensors. UNet keys are prefixed with
        ``model.diffusion_model.``, VAE keys with ``first_stage_model.`` and
        text-encoder keys with ``cond_stage_model.transformer.``. On a key
        clash, later components (VAE, then text encoder) win.

    Raises:
        ValueError: if ``model_path`` is None.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if model_path is None:
        raise ValueError("Must provide a model path!")

    # (sub-directory, file name, diffusers -> SD key converter, key prefix),
    # processed in this order so key insertion order and override semantics
    # match the original {**unet, **vae, **text_enc} merge.
    components = (
        ("unet", "diffusion_pytorch_model.bin",
         convert_unet_state_dict, "model.diffusion_model."),
        ("vae", "diffusion_pytorch_model.bin",
         convert_vae_state_dict, "first_stage_model."),
        ("text_encoder", "pytorch_model.bin",
         convert_text_enc_state_dict, "cond_stage_model.transformer."),
    )

    state_dict = {}
    for subdir, filename, convert, prefix in components:
        raw = torch.load(osp.join(model_path, subdir, filename), map_location="cpu")
        converted = convert(raw)
        state_dict.update({prefix + k: v for k, v in converted.items()})

    if as_half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    return state_dict
def merge(
    path_1: str,
    path_2: str,
    alpha_1: float = 0.5,
):
    """Bake a LoRA into a base Stable Diffusion model and export it as a checkpoint.

    Args:
        path_1: Path or identifier passed to
            ``StableDiffusionPipeline.from_pretrained``.
        path_2: Despite the name and annotation, the already-loaded LoRA
            mapping (``"metadata"`` / ``"weights"``) handed to ``patch_pipe``.
        alpha_1: Scale passed to ``collapse_lora`` for both the UNet and the
            text encoder.

    Returns:
        ``(state_dict, embedding)``. ``state_dict`` is the fp16 checkpoint
        produced by ``lora_convert``. ``embedding`` is a dict holding the
        textual-inversion embeds stacked in sorted-token order under
        ``"string_to_param"["*"]``, or ``None`` when the LoRA contains no
        embeds. Previously the no-embed case crashed in ``torch.stack``.
    """
    import tempfile

    loaded_pipeline = StableDiffusionPipeline.from_pretrained(path_1).to("cpu")
    # The embeds are returned to us but not injected into the tokenizer.
    tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

    # Fold the LoRA deltas into the base weights, then remove the LoRA wrappers.
    models = (loaded_pipeline.unet, loaded_pipeline.text_encoder)
    for model in models:
        collapse_lora(model, alpha_1)
    for model in models:
        monkeypatch_remove_lora(model)

    # Round-trip through save_pretrained so lora_convert can read the diffusers
    # folder layout. A uniquely named directory (still under the CWD, like the
    # old fixed "./merge.tmp") avoids clashes between concurrent merges and is
    # removed even if the conversion raises.
    # NOTE(review): lora_convert reads *.bin files; newer diffusers releases may
    # save safetensors by default — confirm against the pinned diffusers version.
    with tempfile.TemporaryDirectory(prefix="merge.", dir=".") as tmp_output:
        loaded_pipeline.save_pretrained(tmp_output)
        state_dict = lora_convert(tmp_output, as_half=True)

    # torch.stack rejects an empty list, so LoRAs without embeds get no embedding.
    if not tok_dict:
        return state_dict, None

    keys = sorted(tok_dict.keys())
    tok_catted = torch.stack([tok_dict[k] for k in keys])
    ret = {
        # NOTE(review): 265 looks like the fixed token id this embedding format
        # uses for the "*" placeholder — confirm against the consumer.
        "string_to_token": {"*": torch.tensor(265)},
        "string_to_param": {"*": tok_catted},
        "name": "",
    }
    return state_dict, ret