Spaces:
Build error
Build error
File size: 6,481 Bytes
8014209 431af92 8014209 d5d52b4 8014209 a7c8de2 431af92 760bfde 54b4787 760bfde 8014209 760bfde 8014209 d5d52b4 8014209 760bfde d5d52b4 8014209 760bfde 8014209 760bfde 8014209 760bfde d5d52b4 8014209 760bfde 8014209 760bfde 8014209 65ebbd2 8014209 65ebbd2 4ea926c 8014209 65ebbd2 760bfde 8014209 d5d52b4 65ebbd2 7e12112 65ebbd2 d5d52b4 65ebbd2 d5d52b4 760bfde 8014209 760bfde 8014209 65ebbd2 760bfde 65ebbd2 8014209 760bfde 8014209 760bfde 8014209 760bfde 8014209 bced9d4 d5d52b4 760bfde 8014209 760bfde 65ebbd2 08e29b9 a7c8de2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import glob
import os
import shutil
import subprocess
import sys
import tempfile

import gradio as gr
from huggingface_hub import HfApi, upload_folder

import hf_utils
import utils
# Fetch the diffusers repository for its checkpoint-conversion script, which
# convert_and_push runs from ./diffs/scripts/. Skip the clone when an earlier
# start already left the checkout behind: `git clone` into an existing,
# non-empty directory just fails. Only the working tree is needed, so a
# shallow clone is enough.
# NOTE(review): this tracks diffusers `main`, whose conversion-script CLI can
# change over time; consider pinning a tag that matches the installed
# diffusers version.
if not os.path.isdir("diffs"):
    subprocess.run(
        ["git", "clone", "--depth", "1", "https://github.com/huggingface/diffusers.git", "diffs"]
    )
def error_str(error, title="Error"):
    """Render *error* as a Markdown section with a level-4 *title* heading.

    A falsy *error* (``None``, ``""``) produces an empty string, which lets
    callers feed the result straight into a Markdown output to clear it.
    """
    if not error:
        return ""
    heading = f"#### {title}"
    return f"{heading}\n{error}"
def on_token_change(token):
    """Refresh the model picker whenever the token textbox changes.

    Returns three updates, wired to ``(group_model, radio_model_names,
    error_output)``:

    * the model group is shown only when ``hf_utils.get_my_model_names``
      returned a non-empty list;
    * the radio gets those names plus an extra ``"Other"`` entry, with the
      first name preselected (``None`` when there are no names);
    * the error output shows the returned error, or is cleared.
    """
    names, err = hf_utils.get_my_model_names(token)
    found_any = bool(names)
    if found_any:
        # Lets the user type an arbitrary model id / URL instead.
        names.append("Other")

    group_update = gr.update(visible=found_any)
    radio_update = gr.update(choices=names, value=names[0] if found_any else None)
    error_update = gr.update(value=error_str(err))
    return group_update, radio_update, error_update
def url_to_model_id(model_id_str):
    """Normalize a Hub model URL or a plain model id to a repo id.

    ``https://huggingface.co/user/model`` becomes ``user/model``. A trailing
    slash or a page sub-path such as ``/tree/main`` or ``/blob/main/x.ckpt``
    is ignored, and ``http://`` URLs are accepted too. Input that is not a
    Hub URL is returned with surrounding whitespace stripped.
    """
    model_id_str = model_id_str.strip()
    for prefix in ("https://huggingface.co/", "http://huggingface.co/"):
        if model_id_str.startswith(prefix):
            segments = [s for s in model_id_str[len(prefix):].split("/") if s]
            # The repo id is at most the first two path segments
            # (namespace/name); anything after that is page navigation.
            return "/".join(segments[:2])
    return model_id_str
def get_ckpt_names(token, radio_model_names, input_model):
    """List the ``.ckpt`` files of the selected model repo.

    Args:
        token: Hugging Face token used to query the Hub.
        radio_model_names: Model chosen in the radio list, or ``"Other"`` to
            use ``input_model`` instead.
        input_model: Model id or Hub URL typed by the user (used for "Other").

    Returns:
        Updates for ``(error_output, radio_ckpts, group_convert)``. On success
        the error text is cleared, and both the checkpoint radio and the
        convert panel are shown. On any failure an error message is shown and
        the convert panel is hidden.
    """
    model_id = url_to_model_id(input_model or "") if radio_model_names == "Other" else radio_model_names
    hidden_panel = gr.update(visible=False)

    # `not ...` also covers None (e.g. no radio selection), not just "".
    if not token or not model_id:
        return (
            error_str("Please enter both a token and a model name.", title="Invalid input"),
            gr.update(choices=[]),
            hidden_panel,
        )

    try:
        api = HfApi(token=token)
        ckpt_files = [f for f in api.list_repo_files(repo_id=model_id) if f.endswith(".ckpt")]
    except Exception as e:
        # Hide the convert panel just like the other failure paths do,
        # rather than sending None to a layout output.
        return error_str(e), gr.update(choices=[]), hidden_panel

    if not ckpt_files:
        return (
            error_str("No checkpoint files found in the model repo."),
            gr.update(choices=[]),
            hidden_panel,
        )

    return (
        "",  # clear any previous error message
        gr.update(choices=ckpt_files, value=ckpt_files[0], visible=True),
        gr.update(visible=True),
    )
def convert_and_push(radio_model_names, input_model, ckpt_name, token):
    """Convert a CompVis checkpoint to Diffusers and open a PR with the result.

    Args:
        radio_model_names: Model chosen in the radio list, or ``"Other"`` to
            use ``input_model`` instead.
        input_model: Model id or Hub URL typed by the user (used for "Other").
        ckpt_name: Checkpoint file inside the repo to convert.
        token: Hugging Face token used for download and upload.

    Returns:
        Markdown for ``error_output``: a success message linking to the PR,
        or an error message built by ``error_str``.
    """
    model_id = url_to_model_id(input_model or "") if radio_model_names == "Other" else radio_model_names
    if not token or not model_id or not ckpt_name:
        return error_str(
            "Please enter a token and a model name, and choose a checkpoint.",
            title="Invalid input",
        )

    revision = None
    dump_dir = None
    try:
        # 1. Download the checkpoint file
        ckpt_path, revision = hf_utils.download_file(repo_id=model_id, filename=ckpt_name, token=token)

        # 2. Run the conversion script. Its output goes to a fresh private
        # directory instead of a local path derived from the user-supplied
        # model id. check=True turns a failed conversion into an error
        # instead of uploading an empty or partial folder. sys.executable
        # runs the script with the same interpreter/environment as the app.
        dump_dir = tempfile.mkdtemp(prefix="diffusers-", dir=".")
        subprocess.run(
            [
                sys.executable,
                "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
                "--checkpoint_path",
                ckpt_path,
                "--dump_path",
                dump_dir,
            ],
            check=True,
        )

        # 3. Push to the model repo as a pull request
        commit_message = "Add Diffusers weights"
        upload_folder(
            folder_path=dump_dir,
            repo_id=model_id,
            token=token,
            create_pr=True,
            commit_message=commit_message,
            commit_description=f"Add Diffusers weights converted from checkpoint `{ckpt_name}` in revision {revision}",
        )

        pr_url = hf_utils.get_pr_url(HfApi(token=token), model_id, commit_message)
        return (
            "Successfully converted the checkpoint and opened a PR to add the weights to the model repo.\n"
            f"You can view and merge the PR [here]({pr_url})."
        )
    except Exception as e:
        return error_str(e)
    finally:
        # 4. Clean up on every path, not only on success, so failed runs do
        # not leave large files behind on disk.
        _cleanup_conversion_artifacts(revision, dump_dir)


def _cleanup_conversion_artifacts(revision, dump_dir):
    """Best-effort removal of the files a conversion run leaves on disk."""
    if revision is not None:
        try:
            hf_utils.delete_file(revision)
        except Exception as e:  # cleanup must never mask the conversion result
            print(f"Could not delete the downloaded checkpoint: {e}")
    if dump_dir is not None:
        shutil.rmtree(dump_dir, ignore_errors=True)
    # *.yaml* files in the working directory are removed as well (presumably
    # configs fetched during conversion -- confirm before narrowing the glob).
    for path in glob.glob("*.yaml*"):
        subprocess.run(["rm", "-rf", path])
# Markdown intro rendered at the top of the app.
DESCRIPTION = "\n".join(
    [
        "### Convert a stable diffusion checkpoint to Diffusers🧨",
        "With this space, you can easily convert a CompVis stable diffusion checkpoint to Diffusers and automatically create a pull request to the model repo.",
        "You can choose to convert a checkpoint from one of your own models, or from any other model on the Hub.",
    ]
)
def _show_only_for_other(choice):
    """Reveal the free-text model box only when the "Other" option is picked."""
    return gr.update(visible=choice == "Other")


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Step 1: token entry and model selection.
        with gr.Column(scale=11):
            with gr.Column():
                gr.Markdown("## 1. Load model info")
                input_token = gr.Textbox(
                    label="Enter your Hugging Face token",
                    placeholder="READ permission is enough",
                    max_lines=1,
                )
                gr.Markdown("You can get a token [here](https://huggingface.co/settings/tokens)")

                # Hidden until on_token_change reports models for the token.
                with gr.Group(visible=False) as group_model:
                    radio_model_names = gr.Radio(label="Choose a model")
                    input_model = gr.Textbox(
                        label="Model name or URL",
                        placeholder="username/model_name",
                        max_lines=1,
                        visible=False,
                    )
                    btn_get_ckpts = gr.Button("Load")

        # Step 2: checkpoint choice and conversion; hidden until get_ckpt_names succeeds.
        with gr.Column(scale=10):
            with gr.Column(visible=False) as group_convert:
                gr.Markdown("## 2. Convert to Diffusers🧨")
                radio_ckpts = gr.Radio(label="Choose the checkpoint to convert", visible=False)
                gr.Markdown("Conversion may take a few minutes.")
                btn_convert = gr.Button("Convert & Push")

    error_output = gr.Markdown(label="Output")

    # Event wiring. The quick UI updates bypass the queue (queue=False); the
    # long-running conversion uses the default queued behaviour.
    input_token.change(
        fn=on_token_change,
        inputs=input_token,
        outputs=[group_model, radio_model_names, error_output],
        scroll_to_output=True,
        queue=False,
    )
    radio_model_names.change(
        _show_only_for_other,
        inputs=radio_model_names,
        outputs=input_model,
        scroll_to_output=True,
        queue=False,
    )
    btn_get_ckpts.click(
        fn=get_ckpt_names,
        inputs=[input_token, radio_model_names, input_model],
        outputs=[error_output, radio_ckpts, group_convert],
        scroll_to_output=True,
        queue=False,
    )
    btn_convert.click(
        fn=convert_and_push,
        inputs=[radio_model_names, input_model, radio_ckpts, input_token],
        outputs=error_output,
        scroll_to_output=True,
    )

    gr.HTML(
        "\n"
        '<p>Space by: <a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a></p><br>\n'
        '<p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.sd-to-diffusers" alt="visitors"></p>\n'
    )
# Turn on the request queue (the conversion event relies on it), then start
# the server; share=True is requested only when utils.is_google_colab()
# reports a Colab runtime.
demo.queue()
request_share_link = utils.is_google_colab()
demo.launch(share=request_share_link)
|