Commit: apply styling

Files changed:
- app.py +27 -15
- conversion_utils/__init__.py +1 -1
- conversion_utils/text_encoder.py +10 -8
- conversion_utils/unet.py +326 -120
- conversion_utils/utils.py +10 -6
- convert.py +20 -18
- hub_utils/__init__.py +2 -2
- hub_utils/readme.py +2 -2
- hub_utils/repo.py +10 -3
app.py CHANGED
@@ -1,7 +1,7 @@
-import gradio as gr
-from convert import run_conversion
-from hub_utils import save_model_card, push_to_hub
+import gradio as gr
+
+from convert import run_conversion
+from hub_utils import push_to_hub, save_model_card
 
 PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
 DESCRIPTION = """
@@ -20,25 +20,37 @@ This Space lets you convert KerasCV Stable Diffusion weights to a format compatible
 Check [here](https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/dreambooth/train_dreambooth_lora.py#L975) for an example on how you can change the scheduler of an already initialized pipeline.
 """
 
+
 def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
     if text_encoder_weights == "":
-        text_encoder_weights = None
+        text_encoder_weights = None
     if unet_weights == "":
-        unet_weights = None
+        unet_weights = None
     pipeline = run_conversion(text_encoder_weights, unet_weights)
     output_path = "kerascv_sd_diffusers_pipeline"
     pipeline.save_pretrained(output_path)
-    save_model_card(base_model=PRETRAINED_CKPT, repo_folder=output_path, weight_paths=[text_encoder_weights, unet_weights], repo_prefix=repo_prefix)
+    save_model_card(
+        base_model=PRETRAINED_CKPT,
+        repo_folder=output_path,
+        weight_paths=[text_encoder_weights, unet_weights],
+        repo_prefix=repo_prefix,
+    )
     push_str = push_to_hub(hf_token, output_path, repo_prefix)
     return push_str
 
-demo = gr.Interface(
-    title="KerasCV Stable Diffusion to Diffusers Stable Diffusion Pipelines 🧨🤗",
-    description=DESCRIPTION,
-    allow_flagging="never",
-    inputs=[gr.Text(max_lines=1, label="your_hf_token"), gr.Text(max_lines=1, label="text_encoder_weights"), gr.Text(max_lines=1, label="unet_weights"), gr.Text(max_lines=1, label="output_repo_prefix")],
-    outputs=[gr.Markdown(label="output")],
-    fn=run,
-)
 
-demo.launch()
+demo = gr.Interface(
+    title="KerasCV Stable Diffusion to Diffusers Stable Diffusion Pipelines 🧨🤗",
+    description=DESCRIPTION,
+    allow_flagging="never",
+    inputs=[
+        gr.Text(max_lines=1, label="your_hf_token"),
+        gr.Text(max_lines=1, label="text_encoder_weights"),
+        gr.Text(max_lines=1, label="unet_weights"),
+        gr.Text(max_lines=1, label="output_repo_prefix"),
+    ],
+    outputs=[gr.Markdown(label="output")],
+    fn=run,
+)
+
+demo.launch()
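Aside: the handler wired into gr.Interface above can also be exercised without the UI. A minimal sketch, with placeholder values that appear nowhere in this commit:

# Hypothetical direct call to the Space's handler; the token and URL are placeholders.
push_str = run(
    hf_token="hf_xxx",  # a write-scoped Hugging Face token
    text_encoder_weights="",  # empty string means: keep the pre-trained text encoder
    unet_weights="https://example.com/finetuned_unet.h5",
    repo_prefix="my-finetune",
)
print(push_str)  # the markdown string rendered in the gr.Markdown output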
conversion_utils/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .text_encoder import populate_text_encoder
 from .unet import populate_unet
-from .utils import run_assertion
+from .utils import run_assertion
conversion_utils/text_encoder.py CHANGED
@@ -1,16 +1,23 @@
-from keras_cv.models import stable_diffusion
+from typing import Dict
+
 import tensorflow as tf
 import torch
-from typing import Dict
+from keras_cv.models import stable_diffusion
 
 MAX_SEQ_LENGTH = 77
 
+
 def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
     """Populates the state dict from the provided TensorFlow model
     (applicable only for the text encoder)."""
     text_state_dict = dict()
     num_encoder_layers = 0
 
+    # Position ids.
+    text_state_dict["text_model.embeddings.position_ids"] = torch.tensor(
+        list(range(MAX_SEQ_LENGTH))
+    ).unsqueeze(0)
+
     for layer in tf_text_encoder.layers:
         # Embeddings.
         if isinstance(layer, stable_diffusion.text_encoder.CLIPEmbedding):
@@ -102,9 +109,4 @@ def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
             layer.get_weights()[1]
         )
 
-    text_state_dict["text_model.embeddings.position_ids"] = torch.tensor(
-        list(range(MAX_SEQ_LENGTH))
-    ).unsqueeze(0)
-
-    return text_state_dict
+    return text_state_dict
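Aside: the block that moved to the top of populate_text_encoder builds CLIP's position ids. A standalone sanity check of its shape (a sketch, not part of the commit):

import torch

MAX_SEQ_LENGTH = 77
position_ids = torch.tensor(list(range(MAX_SEQ_LENGTH))).unsqueeze(0)
print(position_ids.shape)   # torch.Size([1, 77])
print(position_ids[0, :5])  # tensor([0, 1, 2, 3, 4])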
conversion_utils/unet.py CHANGED
@@ -1,10 +1,14 @@
-import tensorflow as tf
-import torch
-from typing import Dict
 from itertools import product
+from typing import Dict
+
+import tensorflow as tf
+import torch
 from keras_cv.models import stable_diffusion
 
-def port_transformer_block(transformer_block: tf.keras.Model, up_down: int, block_id: int, attention_id: int) -> Dict[str, torch.Tensor]:
+
+def port_transformer_block(
+    transformer_block: tf.keras.Model, up_down: int, block_id: int, attention_id: int
+) -> Dict[str, torch.Tensor]:
     """Populates a Transformer block."""
     transformer_dict = dict()
     if block_id is not None:
@@ -15,36 +19,58 @@ def port_transformer_block(transformer_block: tf.keras.Model, up_down: int, block_id: int, attention_id: int) -> Dict[str, torch.Tensor]:
     # Norms.
     for i in range(1, 4):
         if i == 1:
-            norm = transformer_block.norm1
+            norm = transformer_block.norm1
         elif i == 2:
             norm = transformer_block.norm2
         elif i == 3:
             norm = transformer_block.norm3
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.weight"] = torch.from_numpy(norm.get_weights()[0])
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.bias"] = torch.from_numpy(norm.get_weights()[1])
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.weight"
+        ] = torch.from_numpy(norm.get_weights()[0])
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.bias"
+        ] = torch.from_numpy(norm.get_weights()[1])
 
     # Attentions.
     for i in range(1, 3):
         if i == 1:
             attn = transformer_block.attn1
         else:
             attn = transformer_block.attn2
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_q.weight"] = torch.from_numpy(attn.to_q.get_weights()[0].transpose())
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_k.weight"] = torch.from_numpy(attn.to_k.get_weights()[0].transpose())
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_v.weight"] = torch.from_numpy(attn.to_v.get_weights()[0].transpose())
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.weight"] = torch.from_numpy(attn.out_proj.get_weights()[0].transpose())
-        transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.bias"] = torch.from_numpy(attn.out_proj.get_weights()[1])
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_q.weight"
+        ] = torch.from_numpy(attn.to_q.get_weights()[0].transpose())
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_k.weight"
+        ] = torch.from_numpy(attn.to_k.get_weights()[0].transpose())
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_v.weight"
+        ] = torch.from_numpy(attn.to_v.get_weights()[0].transpose())
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.weight"
+        ] = torch.from_numpy(attn.out_proj.get_weights()[0].transpose())
+        transformer_dict[
+            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.bias"
+        ] = torch.from_numpy(attn.out_proj.get_weights()[1])
 
+    # Dense.
     for i in range(0, 3, 2):
         if i == 0:
             layer = transformer_block.geglu.dense
-            transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
-            transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.bias"] = torch.from_numpy(layer.get_weights()[1])
+            transformer_dict[
+                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.weight"
+            ] = torch.from_numpy(layer.get_weights()[0].transpose())
+            transformer_dict[
+                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.bias"
+            ] = torch.from_numpy(layer.get_weights()[1])
         else:
             layer = transformer_block.dense
-            transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
-            transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.bias"] = torch.from_numpy(layer.get_weights()[1])
+            transformer_dict[
+                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.weight"
+            ] = torch.from_numpy(layer.get_weights()[0].transpose())
+            transformer_dict[
+                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.bias"
+            ] = torch.from_numpy(layer.get_weights()[1])
 
     return transformer_dict
 
@@ -54,7 +80,7 @@ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
     (applicable only for the UNet)."""
     unet_state_dict = dict()
 
-    timstep_emb = 1
+    timstep_emb = 1
     padded_conv = 1
     up_block = 0
 
@@ -67,37 +93,66 @@
     for layer in tf_unet.layers:
         # Timstep embedding.
         if isinstance(layer, tf.keras.layers.Dense):
-            unet_state_dict[f"time_embedding.linear_{timstep_emb}.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
-            unet_state_dict[f"time_embedding.linear_{timstep_emb}.bias"] = torch.from_numpy(layer.get_weights()[1])
+            unet_state_dict[
+                f"time_embedding.linear_{timstep_emb}.weight"
+            ] = torch.from_numpy(layer.get_weights()[0].transpose())
+            unet_state_dict[
+                f"time_embedding.linear_{timstep_emb}.bias"
+            ] = torch.from_numpy(layer.get_weights()[1])
             timstep_emb += 1
 
         # Padded convs (downsamplers).
-        elif isinstance(layer, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
+        elif isinstance(
+            layer, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D
+        ):
             if padded_conv == 1:
                 # Transposition axes taken from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_pytorch_utils.py#L104
-                unet_state_dict["conv_in.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict["conv_in.bias"] = torch.from_numpy(layer.get_weights()[1])
+                unet_state_dict["conv_in.weight"] = torch.from_numpy(
+                    layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict["conv_in.bias"] = torch.from_numpy(
+                    layer.get_weights()[1]
+                )
             elif padded_conv in [2, 3, 4]:
-                unet_state_dict[f"down_blocks.{padded_conv-2}.downsamplers.0.conv.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{padded_conv-2}.downsamplers.0.conv.bias"] = torch.from_numpy(layer.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{padded_conv-2}.downsamplers.0.conv.weight"
+                ] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"down_blocks.{padded_conv-2}.downsamplers.0.conv.bias"
+                ] = torch.from_numpy(layer.get_weights()[1])
             elif padded_conv == 5:
-                unet_state_dict["conv_out.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict["conv_out.bias"] = torch.from_numpy(layer.get_weights()[1])
+                unet_state_dict["conv_out.weight"] = torch.from_numpy(
+                    layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict["conv_out.bias"] = torch.from_numpy(
+                    layer.get_weights()[1]
+                )
+
             padded_conv += 1
 
         # Upsamplers.
         elif isinstance(layer, stable_diffusion.diffusion_model.Upsample):
             conv = layer.conv
-            unet_state_dict[f"up_blocks.{up_block}.upsamplers.0.conv.weight"] = torch.from_numpy(conv.get_weights()[0].transpose(3, 2, 0, 1))
-            unet_state_dict[f"up_blocks.{up_block}.upsamplers.0.conv.bias"] = torch.from_numpy(conv.get_weights()[1])
+            unet_state_dict[
+                f"up_blocks.{up_block}.upsamplers.0.conv.weight"
+            ] = torch.from_numpy(conv.get_weights()[0].transpose(3, 2, 0, 1))
+            unet_state_dict[
+                f"up_blocks.{up_block}.upsamplers.0.conv.bias"
+            ] = torch.from_numpy(conv.get_weights()[1])
             up_block += 1
 
         # Output norms.
-        elif isinstance(layer, stable_diffusion.__internal__.layers.group_normalization.GroupNormalization):
-            unet_state_dict["conv_norm_out.weight"] = torch.from_numpy(layer.get_weights()[0])
-            unet_state_dict["conv_norm_out.bias"] = torch.from_numpy(layer.get_weights()[1])
+        elif isinstance(
+            layer,
+            stable_diffusion.__internal__.layers.group_normalization.GroupNormalization,
+        ):
+            unet_state_dict["conv_norm_out.weight"] = torch.from_numpy(
+                layer.get_weights()[0]
+            )
+            unet_state_dict["conv_norm_out.bias"] = torch.from_numpy(
+                layer.get_weights()[1]
+            )
+
         # All ResBlocks.
         elif isinstance(layer, stable_diffusion.diffusion_model.ResBlock):
             layer_name = layer.name
@@ -105,8 +160,8 @@ def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
 
             # Down.
             if len(parts) == 2 or int(parts[-1]) < 8:
-                entry_flow = layer.entry_flow
-                embedding_flow = layer.embedding_flow
+                entry_flow = layer.entry_flow
+                embedding_flow = layer.embedding_flow
                 exit_flow = layer.exit_flow
 
                 down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
@@ -114,72 +169,138 @@
 
                 # Conv blocks.
                 first_conv_layer = entry_flow[-1]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.weight"
+                ] = torch.from_numpy(
+                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.bias"
+                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                 second_conv_layer = exit_flow[-1]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.weight"
+                ] = torch.from_numpy(
+                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.bias"
+                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
+
+                # Residual blocks.
                 if hasattr(layer, "residual_projection"):
-                    if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
+                    if isinstance(
+                        layer.residual_projection,
+                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
+                    ):
                         residual = layer.residual_projection
-                        unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
-                        unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
+                        unet_state_dict[
+                            f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.weight"
+                        ] = torch.from_numpy(
+                            residual.get_weights()[0].transpose(3, 2, 0, 1)
+                        )
+                        unet_state_dict[
+                            f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.bias"
+                        ] = torch.from_numpy(residual.get_weights()[1])
 
                 # Timestep embedding.
                 embedding_proj = embedding_flow[-1]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.weight"
+                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.bias"
+                ] = torch.from_numpy(embedding_proj.get_weights()[1])
+
                 # Norms.
                 first_group_norm = entry_flow[0]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.weight"
+                ] = torch.from_numpy(first_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.bias"
+                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                 second_group_norm = exit_flow[0]
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
-                unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.weight"
+                ] = torch.from_numpy(second_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.bias"
+                ] = torch.from_numpy(second_group_norm.get_weights()[1])
 
             # Middle.
             elif int(parts[-1]) == 8 or int(parts[-1]) == 9:
-                entry_flow = layer.entry_flow
-                embedding_flow = layer.embedding_flow
+                entry_flow = layer.entry_flow
+                embedding_flow = layer.embedding_flow
                 exit_flow = layer.exit_flow
 
                 mid_resnet_id = int(parts[-1]) % 2
 
                 # Conv blocks.
                 first_conv_layer = entry_flow[-1]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.conv1.weight"
+                ] = torch.from_numpy(
+                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.conv1.bias"
+                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                 second_conv_layer = exit_flow[-1]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.conv2.weight"
+                ] = torch.from_numpy(
+                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.conv2.bias"
+                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
+
+                # Residual blocks.
                 if hasattr(layer, "residual_projection"):
-                    if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
+                    if isinstance(
+                        layer.residual_projection,
+                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
+                    ):
                         residual = layer.residual_projection
-                        unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
-                        unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
+                        unet_state_dict[
+                            f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.weight"
+                        ] = torch.from_numpy(
+                            residual.get_weights()[0].transpose(3, 2, 0, 1)
+                        )
+                        unet_state_dict[
+                            f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.bias"
+                        ] = torch.from_numpy(residual.get_weights()[1])
 
                 # Timestep embedding.
                 embedding_proj = embedding_flow[-1]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.weight"
+                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.bias"
+                ] = torch.from_numpy(embedding_proj.get_weights()[1])
 
                 # Norms.
                 first_group_norm = entry_flow[0]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.norm1.weight"
+                ] = torch.from_numpy(first_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.norm1.bias"
+                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                 second_group_norm = exit_flow[0]
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
-                unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.norm2.weight"
+                ] = torch.from_numpy(second_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"mid_block.resnets.{mid_resnet_id}.norm2.bias"
+                ] = torch.from_numpy(second_group_norm.get_weights()[1])
 
-            # Up.
+            # Up.
             elif int(parts[-1]) > 9 and up_res_block_flag < len(up_res_blocks):
-                entry_flow = layer.entry_flow
-                embedding_flow = layer.embedding_flow
+                entry_flow = layer.entry_flow
+                embedding_flow = layer.embedding_flow
                 exit_flow = layer.exit_flow
 
                 up_res_block = up_res_blocks[up_res_block_flag]
@@ -188,32 +309,65 @@
 
                 # Conv blocks.
                 first_conv_layer = entry_flow[-1]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.weight"
+                ] = torch.from_numpy(
+                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.bias"
+                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                 second_conv_layer = exit_flow[-1]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.weight"
+                ] = torch.from_numpy(
+                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
+                )
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.bias"
+                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
+
+                # Residual blocks.
                 if hasattr(layer, "residual_projection"):
-                    if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
+                    if isinstance(
+                        layer.residual_projection,
+                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
+                    ):
                         residual = layer.residual_projection
-                        unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
-                        unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
+                        unet_state_dict[
+                            f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.weight"
+                        ] = torch.from_numpy(
+                            residual.get_weights()[0].transpose(3, 2, 0, 1)
+                        )
+                        unet_state_dict[
+                            f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.bias"
+                        ] = torch.from_numpy(residual.get_weights()[1])
 
                 # Timestep embedding.
                 embedding_proj = embedding_flow[-1]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.weight"
+                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.bias"
+                ] = torch.from_numpy(embedding_proj.get_weights()[1])
+
                 # Norms.
                 first_group_norm = entry_flow[0]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.weight"
+                ] = torch.from_numpy(first_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.bias"
+                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                 second_group_norm = exit_flow[0]
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
-                unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.weight"
+                ] = torch.from_numpy(second_group_norm.get_weights()[0])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.bias"
+                ] = torch.from_numpy(second_group_norm.get_weights()[1])
+
                 up_res_block_flag += 1
 
         # All SpatialTransformer blocks.
@@ -225,67 +379,119 @@
             if len(parts) == 2 or int(parts[-1]) < 6:
                 down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
                 down_attention_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
 
                 # Convs.
                 proj1 = layer.proj1
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.weight"
+                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.bias"
+                ] = torch.from_numpy(proj1.get_weights()[1])
                 proj2 = layer.proj2
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.weight"
+                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.bias"
+                ] = torch.from_numpy(proj2.get_weights()[1])
 
                 # Transformer blocks.
                 transformer_block = layer.transformer_block
-                unet_state_dict.update(port_transformer_block(transformer_block, "down", down_block_id, down_attention_id))
+                unet_state_dict.update(
+                    port_transformer_block(
+                        transformer_block, "down", down_block_id, down_attention_id
+                    )
+                )
 
                 # Norms.
                 norm = layer.norm
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
-                unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.weight"
+                ] = torch.from_numpy(norm.get_weights()[0])
+                unet_state_dict[
+                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.bias"
+                ] = torch.from_numpy(norm.get_weights()[1])
 
             # Middle.
             elif int(parts[-1]) == 6:
                 mid_attention_id = int(parts[-1]) % 2
                 # Convs.
                 proj1 = layer.proj1
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.proj_in.weight"
+                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.proj_in.bias"
+                ] = torch.from_numpy(proj1.get_weights()[1])
                 proj2 = layer.proj2
-                unet_state_dict[f"mid_block.attentions.{mid_resnet_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_resnet_id}.proj_out.weight"
+                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.proj_out.bias"
+                ] = torch.from_numpy(proj2.get_weights()[1])
 
                 # Transformer blocks.
                 transformer_block = layer.transformer_block
-                unet_state_dict.update(port_transformer_block(transformer_block, "mid", None, mid_attention_id))
+                unet_state_dict.update(
+                    port_transformer_block(
+                        transformer_block, "mid", None, mid_attention_id
+                    )
+                )
 
                 # Norms.
                 norm = layer.norm
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
-                unet_state_dict[f"mid_block.attentions.{mid_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.norm.weight"
+                ] = torch.from_numpy(norm.get_weights()[0])
+                unet_state_dict[
+                    f"mid_block.attentions.{mid_attention_id}.norm.bias"
+                ] = torch.from_numpy(norm.get_weights()[1])
 
             # Up.
-            elif int(parts[-1]) > 6 and up_spatial_transformer_flag < len(up_spatial_transformer_blocks):
-                up_spatial_transformer_block = up_spatial_transformer_blocks[up_spatial_transformer_flag]
+            elif int(parts[-1]) > 6 and up_spatial_transformer_flag < len(
+                up_spatial_transformer_blocks
+            ):
+                up_spatial_transformer_block = up_spatial_transformer_blocks[
+                    up_spatial_transformer_flag
+                ]
                 up_block_id = up_spatial_transformer_block[0]
                 up_attention_id = up_spatial_transformer_block[1]
 
                 # Convs.
                 proj1 = layer.proj1
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.weight"
+                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.bias"
+                ] = torch.from_numpy(proj1.get_weights()[1])
                 proj2 = layer.proj2
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.weight"
+                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.bias"
+                ] = torch.from_numpy(proj2.get_weights()[1])
 
                 # Transformer blocks.
                 transformer_block = layer.transformer_block
-                unet_state_dict.update(port_transformer_block(transformer_block, "up", up_block_id, up_attention_id))
+                unet_state_dict.update(
+                    port_transformer_block(
+                        transformer_block, "up", up_block_id, up_attention_id
+                    )
+                )
 
                 # Norms.
                 norm = layer.norm
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
-                unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.weight"
+                ] = torch.from_numpy(norm.get_weights()[0])
+                unet_state_dict[
+                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.bias"
+                ] = torch.from_numpy(norm.get_weights()[1])
+
                 up_spatial_transformer_flag += 1
 
-    return unet_state_dict
+    return unet_state_dict
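Aside: every convolution kernel ported above goes through .transpose(3, 2, 0, 1) and every dense kernel through a bare .transpose(). A standalone sketch of why (the shapes are illustrative, not taken from this commit):

import numpy as np
import torch

# Keras stores Conv2D kernels as (H, W, C_in, C_out); torch.nn.Conv2d expects
# (C_out, C_in, H, W), hence transpose(3, 2, 0, 1).
tf_conv_kernel = np.random.randn(3, 3, 320, 640).astype("float32")
pt_conv_kernel = torch.from_numpy(tf_conv_kernel.transpose(3, 2, 0, 1))
print(pt_conv_kernel.shape)  # torch.Size([640, 320, 3, 3])

# Keras Dense kernels are (in_features, out_features); torch.nn.Linear stores
# (out_features, in_features), hence the bare .transpose().
tf_dense_kernel = np.random.randn(320, 1280).astype("float32")
pt_dense_kernel = torch.from_numpy(tf_dense_kernel.transpose())
print(pt_dense_kernel.shape)  # torch.Size([1280, 320])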
conversion_utils/utils.py CHANGED
@@ -1,15 +1,19 @@
+from typing import Dict
+
 import numpy as np
-import torch
-from typing import Dict
+import torch
 
 
-def run_assertion(orig_pt_state_dict: Dict[str, torch.Tensor], pt_state_dict_from_tf: Dict[str, torch.Tensor]):
+def run_assertion(
+    orig_pt_state_dict: Dict[str, torch.Tensor],
+    pt_state_dict_from_tf: Dict[str, torch.Tensor],
+):
     for k in orig_pt_state_dict:
         try:
             np.testing.assert_allclose(
-                orig_pt_state_dict[k].numpy(),
-                pt_state_dict_from_tf[k].numpy()
+                orig_pt_state_dict[k].numpy(), pt_state_dict_from_tf[k].numpy()
             )
         except:
-            raise ValueError("There are problems in the parameter population process. Cannot proceed :(")
+            raise ValueError(
+                "There are problems in the parameter population process. Cannot proceed :("
+            )
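Aside: run_assertion leans on np.testing.assert_allclose with NumPy's default tolerances (rtol=1e-7, atol=0), so the two state dicts must agree almost bit-for-bit. A small illustration (not part of the commit):

import numpy as np

a = np.array([1.0, 2.0, 3.0], dtype="float64")
np.testing.assert_allclose(a, a + 1e-9)  # negligible drift: passes

try:
    np.testing.assert_allclose(a, a + 1e-2)  # real mismatch: raises
except AssertionError:
    print("mismatch caught")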
convert.py CHANGED
@@ -1,26 +1,25 @@
-from conversion_utils import populate_text_encoder, populate_unet, run_assertion
-
-from diffusers import (
-    AutoencoderKL,
-    StableDiffusionPipeline,
-    UNet2DConditionModel,
-)
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import CLIPTextModel
 import keras_cv
 import tensorflow as tf
+from diffusers import (AutoencoderKL, StableDiffusionPipeline,
+                       UNet2DConditionModel)
+from diffusers.pipelines.stable_diffusion.safety_checker import \
+    StableDiffusionSafetyChecker
+from transformers import CLIPTextModel
+
+from conversion_utils import (populate_text_encoder, populate_unet,
+                              run_assertion)
 
 PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
 REVISION = None
 NON_EMA_REVISION = None
 IMG_HEIGHT = IMG_WIDTH = 512
 
+
 def initialize_pt_models():
     """Initializes the separate models of Stable Diffusion from diffusers and downloads
     their pre-trained weights."""
     pt_text_encoder = CLIPTextModel.from_pretrained(
-        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
+        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
     )
     pt_vae = AutoencoderKL.from_pretrained(
         PRETRAINED_CKPT, subfolder="vae", revision=REVISION
@@ -34,14 +33,17 @@
 
     return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
 
+
 def initialize_tf_models():
     """Initializes the separate models of Stable Diffusion from KerasCV and downloads
     their pre-trained weights."""
-    tf_sd_model = keras_cv.models.StableDiffusion(img_height=IMG_HEIGHT, img_width=IMG_WIDTH)
-    _ = tf_sd_model.text_to_image("Cartoon")  # To download the weights.
+    tf_sd_model = keras_cv.models.StableDiffusion(
+        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
+    )
+    _ = tf_sd_model.text_to_image("Cartoon")  # To download the weights.
 
-    tf_text_encoder = tf_sd_model.text_encoder
-    tf_vae = tf_sd_model.image_encoder
+    tf_text_encoder = tf_sd_model.text_encoder
+    tf_vae = tf_sd_model.image_encoder
     tf_unet = tf_sd_model.diffusion_model
     return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
@@ -50,7 +52,7 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
     pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
     tf_sd_model, tf_text_encoder, tf_vae, tf_unet = initialize_tf_models()
     print("Pre-trained model weights downloaded.")
-
+
     if text_encoder_weights is not None:
         print("Loading fine-tuned text encoder weights.")
         text_encoder_weights_path = tf.keras.utils.get_file(text_encoder_weights)
@@ -72,7 +74,9 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
     unet_state_dict_from_pt = pt_text_encoder.state_dict()
     run_assertion(unet_state_dict_from_pt, unet_state_dict_from_tf)
 
-    print("Assertions successful, populating the converted parameters into the diffusers models...")
+    print(
+        "Assertions successful, populating the converted parameters into the diffusers models..."
+    )
     pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
     pt_unet.load_state_dict(unet_state_dict_from_tf)
@@ -86,5 +90,3 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
         revision=None,
     )
     return pipeline
-
-
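Aside: run_conversion can also be driven outside the Space. A minimal sketch (the URLs are placeholders; tf.keras.utils.get_file expects real downloadable files):

# Hypothetical end-to-end use; the weight URLs are not from this commit.
pipeline = run_conversion(
    text_encoder_weights="https://example.com/finetuned_text_encoder.h5",
    unet_weights="https://example.com/finetuned_unet.h5",
)
pipeline.save_pretrained("kerascv_sd_diffusers_pipeline")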
hub_utils/__init__.py CHANGED
@@ -1,2 +1,2 @@
-from .readme import save_model_card
-from .repo import push_to_hub
+from .readme import save_model_card
+from .repo import push_to_hub
hub_utils/readme.py CHANGED
@@ -23,7 +23,7 @@ The pipeline contained in this repository was created using [this Space](https://...)
     """
 
     if weight_paths is not None:
-        model_card += "Following weight paths (KerasCV) were used: {weight_paths}"
+        model_card += "Following weight paths (KerasCV) were used: {weight_paths}"
 
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
-        f.write(yaml + model_card)
+        f.write(yaml + model_card)
hub_utils/repo.py CHANGED
@@ -1,5 +1,6 @@
 from huggingface_hub import HfApi, create_repo
 
+
 def push_to_hub(hf_token: str, push_dir: str, repo_prefix: None) -> str:
     try:
         if hf_token == "":
@@ -7,9 +8,15 @@ def push_to_hub(hf_token: str, push_dir: str, repo_prefix: None) -> str:
         else:
             hf_api = HfApi(token=hf_token)
             user = hf_api.whoami()["name"]
-            repo_id = f"{user}/{push_dir}" if repo_prefix == "" else f"{user}/{repo_prefix}-{push_dir}"
+            repo_id = (
+                f"{user}/{push_dir}"
+                if repo_prefix == ""
+                else f"{user}/{repo_prefix}-{push_dir}"
+            )
             _ = create_repo(repo_id=repo_id, token=hf_token)
-            url = hf_api.upload_folder(folder_path=push_dir, repo_id=repo_id, exist_ok=True)
+            url = hf_api.upload_folder(
+                folder_path=push_dir, repo_id=repo_id, exist_ok=True
+            )
             return f"Model successfully pushed: [{url}]({url})"
     except Exception as e:
-        return f"{e}"
+        return f"{e}"
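Aside: the parenthesized conditional above resolves the target repo name. With illustrative values (not from this commit):

user, push_dir = "alice", "kerascv_sd_diffusers_pipeline"  # hypothetical
for repo_prefix in ["", "pokemon"]:
    repo_id = (
        f"{user}/{push_dir}"
        if repo_prefix == ""
        else f"{user}/{repo_prefix}-{push_dir}"
    )
    print(repo_id)
# alice/kerascv_sd_diffusers_pipeline
# alice/pokemon-kerascv_sd_diffusers_pipeline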