|
import argparse |
|
from contextlib import nullcontext |
|
|
|
import safetensors.torch |
|
import torch |
|
from accelerate import init_empty_weights |
|
from huggingface_hub import hf_hub_download |
|
|
|
from diffusers import AutoencoderKL, FluxTransformer2DModel |
|
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint |
|
from diffusers.utils.import_utils import is_accelerate_available |
|
|
|
|
|
"""
# Transformer

python scripts/convert_flux_to_diffusers.py \
--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
--filename "flux1-schnell.sft" \
--output_path "flux-schnell" \
--transformer
"""
|
|
|
"""
# VAE

python scripts/convert_flux_to_diffusers.py \
--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
--filename "ae.sft" \
--output_path "flux-schnell" \
--vae
"""
|
|
|
# Context-manager factory: accelerate's `init_empty_weights` when accelerate is
# installed, otherwise the no-op `nullcontext`.
# `is_accelerate_available` is a function and must be *called*; the bare
# function object is always truthy, which made the fallback unreachable.
# NOTE(review): CTX is not referenced anywhere else in the visible code.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser(
    description="Convert an original Flux checkpoint (transformer and/or VAE) to the Diffusers format."
)
parser.add_argument(
    "--original_state_dict_repo_id",
    default=None,
    type=str,
    help="Hub repo id to download the checkpoint from; takes precedence over --checkpoint_path.",
)
parser.add_argument(
    "--filename",
    default="flux.safetensors",
    type=str,
    help="File to download from --original_state_dict_repo_id.",
)
parser.add_argument(
    "--checkpoint_path",
    default=None,
    type=str,
    help="Local safetensors checkpoint, used when no repo id is given.",
)
parser.add_argument("--vae", action="store_true", help="Convert and save the VAE.")
parser.add_argument("--transformer", action="store_true", help="Convert and save the transformer.")
parser.add_argument(
    "--output_path",
    type=str,
    help="Directory that receives the `transformer/` and/or `vae/` subfolders.",
)
parser.add_argument(
    "--dtype",
    type=str,
    default="bf16",
    help='Dtype of the saved weights: "bf16" selects torch.bfloat16, any other value selects torch.float32.',
)

# Arguments are parsed at import time; `args` and `dtype` are module-level
# globals that `main` and the `__main__` guard read.
args = parser.parse_args()
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
|
|
|
|
|
def load_original_checkpoint(args):
    """Load the original checkpoint selected by ``args`` as a state dict.

    When ``args.original_state_dict_repo_id`` is set, ``args.filename`` is
    fetched from that repo with ``hf_hub_download``; the repo id wins even if
    ``args.checkpoint_path`` is also given. Otherwise ``args.checkpoint_path``
    is used as the file to load.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for the resolved file.

    Raises:
        ValueError: if neither a repo id nor a local checkpoint path is set.
    """
    repo_id = args.original_state_dict_repo_id
    if repo_id is not None:
        downloaded = hf_hub_download(repo_id=repo_id, filename=args.filename)
        return safetensors.torch.load_file(downloaded)

    local_file = args.checkpoint_path
    if local_file is None:
        raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    return safetensors.torch.load_file(local_file)
|
|
|
|
|
|
|
|
|
def swap_scale_shift(weight):
    """Return ``weight`` with its two dim-0 chunks in exchanged order.

    The tensor is cut into two chunks along dim 0 (named shift and scale in
    the original code) and re-joined with the second chunk first.
    """
    leading, trailing = torch.chunk(weight, 2, dim=0)
    return torch.cat((trailing, leading), dim=0)
|
|
|
|
|
def _flux_weight_and_bias_pairs(module_pairs):
    """Expand (diffusers_module, original_module) pairs into per-parameter key pairs.

    Every module contributes its ``.weight`` pair followed by its ``.bias`` pair.
    """
    return tuple(
        (f"{new_module}.{param}", f"{old_module}.{param}")
        for new_module, old_module in module_pairs
        for param in ("weight", "bias")
    )


def _flux_move_renamed(src, dst, key_pairs, old_prefix="", new_prefix=""):
    """Pop every original key from ``src`` and store its tensor in ``dst`` under the new key."""
    for new_key, old_key in key_pairs:
        dst[f"{new_prefix}{new_key}"] = src.pop(f"{old_prefix}{old_key}")


def _flux_store_projections(dst, new_prefix, projections):
    """Store (name, weight, bias) triples as ``<new_prefix><name>.weight`` / ``.bias``."""
    for name, weight, bias in projections:
        # torch.cat over a one-element list is kept from the original code;
        # presumably it turns each slice of the fused tensor into a tensor of
        # its own rather than a view -- TODO confirm the intent.
        dst[f"{new_prefix}{name}.weight"] = torch.cat([weight])
        dst[f"{new_prefix}{name}.bias"] = torch.cat([bias])


# (diffusers key, original key) pairs for parameters that are renamed only.
_FLUX_TIME_TEXT_EMBEDDER_KEYS = _flux_weight_and_bias_pairs(
    (
        ("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer"),
        ("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer"),
        ("time_text_embed.text_embedder.linear_1", "vector_in.in_layer"),
        ("time_text_embed.text_embedder.linear_2", "vector_in.out_layer"),
    )
)
_FLUX_GUIDANCE_EMBEDDER_KEYS = _flux_weight_and_bias_pairs(
    (
        ("time_text_embed.guidance_embedder.linear_1", "guidance_in.in_layer"),
        ("time_text_embed.guidance_embedder.linear_2", "guidance_in.out_layer"),
    )
)
_FLUX_INPUT_EMBEDDER_KEYS = _flux_weight_and_bias_pairs(
    (
        ("context_embedder", "txt_in"),
        ("x_embedder", "img_in"),
    )
)
_FLUX_OUTPUT_PROJECTION_KEYS = _flux_weight_and_bias_pairs((("proj_out", "final_layer.linear"),))

# Per-block tables; keys are relative to the block prefix.
_FLUX_DOUBLE_BLOCK_MODULATION_KEYS = _flux_weight_and_bias_pairs(
    (
        ("norm1.linear", "img_mod.lin"),
        ("norm1_context.linear", "txt_mod.lin"),
    )
)
_FLUX_DOUBLE_BLOCK_REMAINING_KEYS = (
    ("attn.norm_q.weight", "img_attn.norm.query_norm.scale"),
    ("attn.norm_k.weight", "img_attn.norm.key_norm.scale"),
    ("attn.norm_added_q.weight", "txt_attn.norm.query_norm.scale"),
    ("attn.norm_added_k.weight", "txt_attn.norm.key_norm.scale"),
) + _flux_weight_and_bias_pairs(
    (
        ("ff.net.0.proj", "img_mlp.0"),
        ("ff.net.2", "img_mlp.2"),
        ("ff_context.net.0.proj", "txt_mlp.0"),
        ("ff_context.net.2", "txt_mlp.2"),
        ("attn.to_out.0", "img_attn.proj"),
        ("attn.to_add_out", "txt_attn.proj"),
    )
)
_FLUX_SINGLE_BLOCK_MODULATION_KEYS = _flux_weight_and_bias_pairs((("norm.linear", "modulation.lin"),))
_FLUX_SINGLE_BLOCK_REMAINING_KEYS = (
    ("attn.norm_q.weight", "norm.query_norm.scale"),
    ("attn.norm_k.weight", "norm.key_norm.scale"),
) + _flux_weight_and_bias_pairs((("proj_out", "linear2"),))


def _flux_convert_double_block(src, dst, index):
    """Move the parameters of ``double_blocks.<index>`` to ``transformer_blocks.<index>``."""
    old_prefix = f"double_blocks.{index}."
    new_prefix = f"transformer_blocks.{index}."

    _flux_move_renamed(src, dst, _FLUX_DOUBLE_BLOCK_MODULATION_KEYS, old_prefix, new_prefix)

    # Each fused qkv tensor is cut into three chunks along dim 0: q, k, v.
    img_q, img_k, img_v = torch.chunk(src.pop(f"{old_prefix}img_attn.qkv.weight"), 3, dim=0)
    txt_q, txt_k, txt_v = torch.chunk(src.pop(f"{old_prefix}txt_attn.qkv.weight"), 3, dim=0)
    img_q_bias, img_k_bias, img_v_bias = torch.chunk(src.pop(f"{old_prefix}img_attn.qkv.bias"), 3, dim=0)
    txt_q_bias, txt_k_bias, txt_v_bias = torch.chunk(src.pop(f"{old_prefix}txt_attn.qkv.bias"), 3, dim=0)
    _flux_store_projections(
        dst,
        new_prefix,
        (
            ("attn.to_q", img_q, img_q_bias),
            ("attn.to_k", img_k, img_k_bias),
            ("attn.to_v", img_v, img_v_bias),
            ("attn.add_q_proj", txt_q, txt_q_bias),
            ("attn.add_k_proj", txt_k, txt_k_bias),
            ("attn.add_v_proj", txt_v, txt_v_bias),
        ),
    )

    _flux_move_renamed(src, dst, _FLUX_DOUBLE_BLOCK_REMAINING_KEYS, old_prefix, new_prefix)


def _flux_convert_single_block(src, dst, index, split_sizes):
    """Move the parameters of ``single_blocks.<index>`` to ``single_transformer_blocks.<index>``."""
    old_prefix = f"single_blocks.{index}."
    new_prefix = f"single_transformer_blocks.{index}."

    _flux_move_renamed(src, dst, _FLUX_SINGLE_BLOCK_MODULATION_KEYS, old_prefix, new_prefix)

    # `linear1` fuses q, k, v and the MLP input projection along dim 0.
    q, k, v, mlp = torch.split(src.pop(f"{old_prefix}linear1.weight"), split_sizes, dim=0)
    q_bias, k_bias, v_bias, mlp_bias = torch.split(src.pop(f"{old_prefix}linear1.bias"), split_sizes, dim=0)
    _flux_store_projections(
        dst,
        new_prefix,
        (
            ("attn.to_q", q, q_bias),
            ("attn.to_k", k, k_bias),
            ("attn.to_v", v, v_bias),
            ("proj_mlp", mlp, mlp_bias),
        ),
    )

    _flux_move_renamed(src, dst, _FLUX_SINGLE_BLOCK_REMAINING_KEYS, old_prefix, new_prefix)


def convert_flux_transformer_checkpoint_to_diffusers(
    original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
):
    """Re-key an original Flux transformer state dict into the Diffusers layout.

    Most parameters are renamed only. Fused tensors are split along dim 0:
    the per-stream ``qkv`` tensors of the double blocks into three chunks, and
    ``linear1`` of the single blocks into q, k, v and the MLP projection. The
    final modulation weight and bias go through ``swap_scale_shift``.

    Args:
        original_state_dict: mapping from original key to tensor. It is
            mutated: every converted entry is popped from it.
        num_layers: number of ``double_blocks.<i>`` entries to convert.
        num_single_layers: number of ``single_blocks.<i>`` entries to convert.
        inner_dim: size of each of the q, k and v slices of ``linear1``.
        mlp_ratio: the MLP slice of ``linear1`` has ``int(inner_dim * mlp_ratio)`` rows.

    Returns:
        A new dict mapping Diffusers keys to tensors.

    Raises:
        KeyError: if an expected key is missing from ``original_state_dict``.
    """
    converted_state_dict = {}

    _flux_move_renamed(original_state_dict, converted_state_dict, _FLUX_TIME_TEXT_EMBEDDER_KEYS)

    # The guidance embedder is converted only when the checkpoint has
    # at least one key containing "guidance".
    if any("guidance" in k for k in original_state_dict):
        _flux_move_renamed(original_state_dict, converted_state_dict, _FLUX_GUIDANCE_EMBEDDER_KEYS)

    _flux_move_renamed(original_state_dict, converted_state_dict, _FLUX_INPUT_EMBEDDER_KEYS)

    for i in range(num_layers):
        _flux_convert_double_block(original_state_dict, converted_state_dict, i)

    # The split sizes do not depend on the block index, so compute them once.
    mlp_hidden_dim = int(inner_dim * mlp_ratio)
    split_sizes = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
    for i in range(num_single_layers):
        _flux_convert_single_block(original_state_dict, converted_state_dict, i, split_sizes)

    _flux_move_renamed(original_state_dict, converted_state_dict, _FLUX_OUTPUT_PROJECTION_KEYS)

    # The two halves of the final modulation projection are exchanged by
    # `swap_scale_shift` before being stored.
    final_modulation_weight = original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(final_modulation_weight)
    final_modulation_bias = original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(final_modulation_bias)

    return converted_state_dict
|
|
|
|
|
def main(args):
    """Convert the requested Flux component(s) and save them in Diffusers format.

    Loads the original checkpoint selected by ``args``. With
    ``args.transformer`` the transformer weights are converted and saved under
    ``<output_path>/transformer``; with ``args.vae`` the VAE weights are
    converted and saved under ``<output_path>/vae``. Saved weights use the
    module-level ``dtype``.

    Raises:
        ValueError: if a component is requested but ``args.output_path`` is
            not set, or if no checkpoint source is given.
    """
    # Fail before downloading or converting anything: without this check the
    # f-strings below would write into a directory literally named "None".
    if (args.transformer or args.vae) and args.output_path is None:
        raise ValueError("please provide `output_path` when converting with `--transformer` or `--vae`")

    original_ckpt = load_original_checkpoint(args)
    # Evaluated before the transformer conversion, which pops keys out of the checkpoint dict.
    has_guidance = any("guidance" in k for k in original_ckpt)

    if args.transformer:
        # Hard-coded architecture sizes handed to the converter.
        num_layers = 19
        num_single_layers = 38
        inner_dim = 3072
        mlp_ratio = 4.0
        converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
            original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
        )
        transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)

        print(
            f"Saving Flux Transformer in Diffusers format. Variant: {'guidance-distilled' if has_guidance else 'timestep-distilled'}"
        )
        transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")

    if args.vae:
        config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
        # Build the VAE in the requested dtype. It used to be created in
        # bfloat16 unconditionally, so with any other --dtype the loaded weights
        # were rounded to bf16 first and only cast back afterwards.
        vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(dtype)

        converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
        vae.load_state_dict(converted_vae_state_dict, strict=True)
        vae.to(dtype).save_pretrained(f"{args.output_path}/vae")


if __name__ == "__main__":
    # `args` is the module-level namespace parsed at import time.
    main(args)
|
|