from itertools import product
from typing import Dict, Optional

import tensorflow as tf
import torch
from keras_cv.models import stable_diffusion

# Axis permutation applied to every convolution kernel before it is stored
# under a PyTorch-style key.
# Transposition axes taken from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_pytorch_utils.py#L104
_CONV_KERNEL_AXES = (3, 2, 0, 1)


def _port_norm(state_dict: Dict[str, torch.Tensor], key: str, layer) -> None:
    """Store a normalization layer's first two weights, unchanged, under ``key``."""
    weights = layer.get_weights()
    state_dict[f"{key}.weight"] = torch.from_numpy(weights[0])
    state_dict[f"{key}.bias"] = torch.from_numpy(weights[1])


def _port_linear(
    state_dict: Dict[str, torch.Tensor], key: str, layer, with_bias: bool = True
) -> None:
    """Store a dense layer under ``key``: transposed kernel, plus bias if requested."""
    weights = layer.get_weights()
    state_dict[f"{key}.weight"] = torch.from_numpy(weights[0].transpose())
    if with_bias:
        state_dict[f"{key}.bias"] = torch.from_numpy(weights[1])


def _port_conv(state_dict: Dict[str, torch.Tensor], key: str, layer) -> None:
    """Store a convolution under ``key``: kernel permuted by ``_CONV_KERNEL_AXES``, plus bias."""
    weights = layer.get_weights()
    state_dict[f"{key}.weight"] = torch.from_numpy(weights[0].transpose(*_CONV_KERNEL_AXES))
    state_dict[f"{key}.bias"] = torch.from_numpy(weights[1])


def _layer_index(layer_name: str) -> int:
    """Return the numeric suffix of a layer name.

    The name is split on ``"_"``. A name with exactly two parts (e.g.
    ``"res_block"``) is treated as index 0; otherwise the last part is parsed
    as an integer (e.g. ``"res_block_3"`` -> 3).

    Raises:
        ValueError: If the name has other than two parts and its last part is
            not an integer.
    """
    parts = layer_name.split("_")
    return 0 if len(parts) == 2 else int(parts[-1])


def _port_res_block(state_dict: Dict[str, torch.Tensor], prefix: str, layer) -> None:
    """Store every weight of one ResBlock under ``prefix``.

    Reads the first and last entries of ``entry_flow`` and ``exit_flow`` (norm
    and conv respectively), the last entry of ``embedding_flow`` (timestep
    projection) and, when present, a convolutional ``residual_projection``.
    """
    entry_flow = layer.entry_flow
    embedding_flow = layer.embedding_flow
    exit_flow = layer.exit_flow

    # Conv blocks.
    _port_conv(state_dict, f"{prefix}.conv1", entry_flow[-1])
    _port_conv(state_dict, f"{prefix}.conv2", exit_flow[-1])

    # Residual projection: only ported when it is an actual convolution.
    residual = getattr(layer, "residual_projection", None)
    if isinstance(residual, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
        _port_conv(state_dict, f"{prefix}.conv_shortcut", residual)

    # Timestep embedding.
    _port_linear(state_dict, f"{prefix}.time_emb_proj", embedding_flow[-1])

    # Norms.
    _port_norm(state_dict, f"{prefix}.norm1", entry_flow[0])
    _port_norm(state_dict, f"{prefix}.norm2", exit_flow[0])


def _port_spatial_transformer(
    state_dict: Dict[str, torch.Tensor],
    layer,
    up_down: str,
    block_id: Optional[int],
    attention_id: int,
) -> None:
    """Store every weight of one SpatialTransformer layer.

    Keys start with ``"{up_down}_blocks.{block_id}.attentions.{attention_id}"``,
    or ``"mid_block.attentions.{attention_id}"`` when ``block_id`` is ``None``.
    """
    if block_id is not None:
        prefix = f"{up_down}_blocks.{block_id}.attentions.{attention_id}"
    else:
        prefix = f"mid_block.attentions.{attention_id}"

    # Convs.
    _port_conv(state_dict, f"{prefix}.proj_in", layer.proj1)
    _port_conv(state_dict, f"{prefix}.proj_out", layer.proj2)

    # Transformer blocks.
    state_dict.update(
        port_transformer_block(layer.transformer_block, up_down, block_id, attention_id)
    )

    # Norms.
    _port_norm(state_dict, f"{prefix}.norm", layer.norm)


def port_transformer_block(
    transformer_block: tf.keras.Model,
    up_down: str,
    block_id: Optional[int],
    attention_id: int,
) -> Dict[str, torch.Tensor]:
    """Populates a Transformer block.

    Args:
        transformer_block: Object exposing ``norm1``..``norm3``, ``attn1``,
            ``attn2`` (each with ``to_q``, ``to_k``, ``to_v``, ``out_proj``),
            ``geglu.dense`` and ``dense``.
        up_down: Key prefix stem, e.g. ``"down"`` or ``"up"``. Ignored when
            ``block_id`` is ``None``.
        block_id: Index of the enclosing block, or ``None`` to emit keys under
            ``"mid_block"``.
        attention_id: Index of the attention module inside the block.

    Returns:
        A dict mapping PyTorch-style parameter names to tensors. Dense kernels
        are transposed; norm weights and all biases are copied as-is.
    """
    if block_id is not None:
        prefix = f"{up_down}_blocks.{block_id}"
    else:
        prefix = "mid_block"
    base = f"{prefix}.attentions.{attention_id}.transformer_blocks.0"

    transformer_dict: Dict[str, torch.Tensor] = {}

    # Norms.
    norms = (transformer_block.norm1, transformer_block.norm2, transformer_block.norm3)
    for i, norm in enumerate(norms, start=1):
        _port_norm(transformer_dict, f"{base}.norm{i}", norm)

    # Attentions. Only the kernels of q/k/v are exported; the output
    # projection also exports its bias.
    attentions = (transformer_block.attn1, transformer_block.attn2)
    for i, attn in enumerate(attentions, start=1):
        _port_linear(transformer_dict, f"{base}.attn{i}.to_q", attn.to_q, with_bias=False)
        _port_linear(transformer_dict, f"{base}.attn{i}.to_k", attn.to_k, with_bias=False)
        _port_linear(transformer_dict, f"{base}.attn{i}.to_v", attn.to_v, with_bias=False)
        _port_linear(transformer_dict, f"{base}.attn{i}.to_out.0", attn.out_proj)

    # Dense (feed-forward): the GEGLU projection lands at ff.net.0.proj and the
    # final dense layer at ff.net.2.
    _port_linear(transformer_dict, f"{base}.ff.net.0.proj", transformer_block.geglu.dense)
    _port_linear(transformer_dict, f"{base}.ff.net.2", transformer_block.dense)

    return transformer_dict


def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
    """Populates the state dict from the provided TensorFlow model (applicable only for the UNet).

    Walks ``tf_unet.layers`` in order and dispatches on layer type. The
    destination key of each layer depends on how many layers of that type have
    been seen so far and, for ResBlock / SpatialTransformer layers, on the
    numeric suffix of ``layer.name``. The suffix thresholds below are
    hard-coded.
    NOTE(review): this presumably requires the Keras auto-generated names of a
    freshly built model (``res_block``, ``res_block_1``, ...) -- confirm
    against the caller.

    Args:
        tf_unet: Model whose ``layers`` are iterated.

    Returns:
        A dict mapping PyTorch-style parameter names to tensors.
    """
    unet_state_dict: Dict[str, torch.Tensor] = {}

    internal_layers = stable_diffusion.__internal__.layers
    padded_conv_cls = internal_layers.padded_conv2d.PaddedConv2D
    group_norm_cls = internal_layers.group_normalization.GroupNormalization
    diffusion_model = stable_diffusion.diffusion_model

    timestep_emb = 1
    padded_conv = 1
    up_block = 0
    # (block, index-within-block) pairs handed out, in order, to the "up"
    # ResBlocks and SpatialTransformers as they are encountered.
    up_res_block_ids = product([0, 1, 2, 3], [0, 1, 2])
    up_spatial_transformer_ids = product([1, 2, 3], [0, 1, 2])

    for layer in tf_unet.layers:
        # Timestep embedding.
        if isinstance(layer, tf.keras.layers.Dense):
            _port_linear(unet_state_dict, f"time_embedding.linear_{timestep_emb}", layer)
            timestep_emb += 1

        # Padded convs: 1st is conv_in, 2nd-4th are downsamplers, 5th is conv_out.
        elif isinstance(layer, padded_conv_cls):
            if padded_conv == 1:
                _port_conv(unet_state_dict, "conv_in", layer)
            elif padded_conv in (2, 3, 4):
                _port_conv(
                    unet_state_dict,
                    f"down_blocks.{padded_conv - 2}.downsamplers.0.conv",
                    layer,
                )
            elif padded_conv == 5:
                _port_conv(unet_state_dict, "conv_out", layer)
            padded_conv += 1

        # Upsamplers.
        elif isinstance(layer, diffusion_model.Upsample):
            _port_conv(unet_state_dict, f"up_blocks.{up_block}.upsamplers.0.conv", layer.conv)
            up_block += 1

        # Output norms.
        elif isinstance(layer, group_norm_cls):
            _port_norm(unet_state_dict, "conv_norm_out", layer)

        # All ResBlocks: suffix 0-7 -> down, 8-9 -> mid, 10+ -> up.
        elif isinstance(layer, diffusion_model.ResBlock):
            index = _layer_index(layer.name)
            if index < 8:
                _port_res_block(
                    unet_state_dict, f"down_blocks.{index // 2}.resnets.{index % 2}", layer
                )
            elif index in (8, 9):
                _port_res_block(unet_state_dict, f"mid_block.resnets.{index % 2}", layer)
            else:
                # Surplus up blocks (beyond the 12 known slots) are ignored.
                ids = next(up_res_block_ids, None)
                if ids is not None:
                    _port_res_block(
                        unet_state_dict, f"up_blocks.{ids[0]}.resnets.{ids[1]}", layer
                    )

        # All SpatialTransformer blocks: suffix 0-5 -> down, 6 -> mid, 7+ -> up.
        elif isinstance(layer, diffusion_model.SpatialTransformer):
            index = _layer_index(layer.name)
            if index < 6:
                _port_spatial_transformer(unet_state_dict, layer, "down", index // 2, index % 2)
            elif index == 6:
                # Every mid-block key is derived from the attention index
                # itself, never from state left over by the ResBlock branch.
                _port_spatial_transformer(unet_state_dict, layer, "mid", None, index % 2)
            else:
                # Surplus up blocks (beyond the 9 known slots) are ignored.
                ids = next(up_spatial_transformer_ids, None)
                if ids is not None:
                    _port_spatial_transformer(unet_state_dict, layer, "up", ids[0], ids[1])

    return unet_state_dict