|
import argparse |
|
|
|
import safetensors.torch |
|
|
|
from diffusers import AutoencoderTiny |
|
|
|
|
|
""" |
|
Example - From the diffusers root directory: |
|
|
|
Download the weights: |
|
```sh |
|
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors |
|
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors |
|
``` |
|
|
|
Convert the model: |
|
```sh |
|
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \ |
|
--encoder_ckpt_path taesd_encoder.safetensors \ |
|
--decoder_ckpt_path taesd_decoder.safetensors \ |
|
--dump_path taesd-diffusers |
|
``` |
|
""" |
|
|
|
def remap_tiny_autoencoder_keys(encoder_state_dict, decoder_state_dict):
    """Build a single diffusers-style state dict from the original encoder/decoder state dicts.

    Args:
        encoder_state_dict: Mapping of original encoder parameter names to their values.
        decoder_state_dict: Mapping of original decoder parameter names to their values. Every
            key must start with an integer layer index followed by a ``.``.

    Returns:
        A new dict holding the encoder entries first, then the decoder entries:

        * encoder keys are prefixed with ``encoder.layers.``;
        * decoder keys have their leading layer index decremented by one and are prefixed
          with ``decoder.layers.``.

    Raises:
        ValueError: If the leading component of a decoder key is not an integer.
    """
    remapped = {f"encoder.layers.{key}": value for key, value in encoder_state_dict.items()}

    for key, value in decoder_state_dict.items():
        # Only the first dotted component is the layer index; everything after it is kept verbatim.
        layer_index, _, remainder = key.partition(".")
        # The decoder layer index is shifted down by one to line up with `decoder.layers`.
        # NOTE(review): presumably the original decoder has a parameter-free layer at index 0 — confirm.
        remapped[f"decoder.layers.{int(layer_index) - 1}.{remainder}"] = value

    return remapped


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # `required=True` already guarantees a value, so no explicit `default` is needed.
    parser.add_argument("--dump_path", type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--encoder_ckpt_path",
        type=str,
        required=True,
        help="Path to the encoder ckpt.",
    )
    parser.add_argument(
        "--decoder_ckpt_path",
        type=str,
        required=True,
        help="Path to the decoder ckpt.",
    )
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = AutoencoderTiny()
    new_state_dict = remap_tiny_autoencoder_keys(encoder_state_dict, decoder_state_dict)

    # Load the renamed weights into the freshly constructed model before serializing it.
    tiny_autoencoder.load_state_dict(new_state_dict)
    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
|
|