As a workaround, this would be fine: rather than casting the converted pipeline in memory with `pipeline.to(**to_args)`, save the fp32 weights first, then reload them from disk in `torch.float16` and save the fp16 variant from that reloaded copy.
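
In isolation, the workaround looks like this (a minimal sketch of the non-ControlNet path; the local `folder` path is a placeholder, and it assumes the fp32 weights were already saved in safetensors format, as the script does):

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder for the directory that convert_single() has already filled
# with fp32 weights via pipeline.save_pretrained(folder, safe_serialization=True).
folder = "path/to/converted-pipeline"

# Rather than halving the in-memory pipeline with pipeline.to(**to_args),
# reload the saved safetensors weights directly in half precision...
pipeline = StableDiffusionPipeline.from_pretrained(
    folder, use_safetensors=True, torch_dtype=torch.float16
)

# ...and write them back as the "fp16" variant, in both PyTorch and
# safetensors formats, matching what the diff below does.
pipeline.save_pretrained(folder, variant="fp16")
pipeline.save_pretrained(folder, safe_serialization=True, variant="fp16")
```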

#17
by John6666 - opened
Files changed (1)
  1. convert.py +105 -103
convert.py CHANGED
@@ -1,103 +1,105 @@
-import gradio as gr
-import requests
-import os
-import shutil
-from pathlib import Path
-import tempfile
-from tempfile import TemporaryDirectory
-
-
-from typing import Optional
-
-import torch
-from io import BytesIO
-
-from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
-from huggingface_hub.file_download import repo_folder_name
-from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-    download_from_original_stable_diffusion_ckpt, download_controlnet_from_original_ckpt
-)
-from transformers import CONFIG_MAPPING
-
-
-COMMIT_MESSAGE = " This PR adds fp32 and fp16 weights in PyTorch and safetensors format to {}"
-
-
-def convert_single(model_id: str, token:str, filename: str, model_type: str, sample_size: int, scheduler_type: str, extract_ema: bool, folder: str, progress):
-    from_safetensors = filename.endswith(".safetensors")
-
-    progress(0, desc="Downloading model")
-    local_file = os.path.join(model_id, filename)
-    ckpt_file = local_file if os.path.isfile(local_file) else hf_hub_download(repo_id=model_id, filename=filename, token=token)
-
-    if model_type == "v1":
-        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-    elif model_type == "v2":
-        if sample_size == 512:
-            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference.yaml"
-        else:
-            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
-    elif model_type == "ControlNet":
-        config_url = (Path(model_id)/"resolve/main"/filename).with_suffix(".yaml")
-        config_url = "https://huggingface.co/" + str(config_url)
-
-    #config_file = BytesIO(requests.get(config_url).content)
-
-    response = requests.get(config_url)
-    with tempfile.NamedTemporaryFile(delete=False, mode='wb') as tmp_file:
-        tmp_file.write(response.content)
-        temp_config_file_path = tmp_file.name
-
-    if model_type == "ControlNet":
-        progress(0.2, desc="Converting ControlNet Model")
-        pipeline = download_controlnet_from_original_ckpt(ckpt_file, temp_config_file_path, image_size=sample_size, from_safetensors=from_safetensors, extract_ema=extract_ema)
-        to_args = {"dtype": torch.float16}
-    else:
-        progress(0.1, desc="Converting Model")
-        pipeline = download_from_original_stable_diffusion_ckpt(ckpt_file, temp_config_file_path, image_size=sample_size, scheduler_type=scheduler_type, from_safetensors=from_safetensors, extract_ema=extract_ema)
-        to_args = {"torch_dtype": torch.float16}
-
-    pipeline.save_pretrained(folder)
-    pipeline.save_pretrained(folder, safe_serialization=True)
-
-    pipeline = pipeline.to(**to_args)
-    pipeline.save_pretrained(folder, variant="fp16")
-    pipeline.save_pretrained(folder, safe_serialization=True, variant="fp16")
-
-    return folder
-
-
-def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
-    try:
-        discussions = api.get_repo_discussions(repo_id=model_id)
-    except Exception:
-        return None
-    for discussion in discussions:
-        if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
-            details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
-            if details.target_branch == "refs/heads/main":
-                return discussion
-
-
-def convert(token: str, model_id: str, filename: str, model_type: str, sample_size: int = 512, scheduler_type: str = "pndm", extract_ema: bool = True, progress=gr.Progress()):
-    api = HfApi()
-
-    pr_title = "Adding `diffusers` weights of this model"
-
-    with TemporaryDirectory() as d:
-        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
-        os.makedirs(folder)
-        new_pr = None
-        try:
-            folder = convert_single(model_id, token, filename, model_type, sample_size, scheduler_type, extract_ema, folder, progress)
-            progress(0.7, desc="Uploading to Hub")
-            new_pr = api.upload_folder(folder_path=folder, path_in_repo="./", repo_id=model_id, repo_type="model", token=token, commit_message=pr_title, commit_description=COMMIT_MESSAGE.format(model_id), create_pr=True)
-            pr_number = new_pr.split("%2F")[-1].split("/")[0]
-            link = f"Pr created at: {'https://huggingface.co/' + os.path.join(model_id, 'discussions', pr_number)}"
-            progress(1, desc="Done")
-        except Exception as e:
-            raise gr.exceptions.Error(str(e))
-        finally:
-            shutil.rmtree(folder)
-
-    return link
+import gradio as gr
+import requests
+import os
+import shutil
+from pathlib import Path
+import tempfile
+from tempfile import TemporaryDirectory
+
+
+from typing import Optional
+
+import torch
+from io import BytesIO
+
+from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
+from huggingface_hub.file_download import repo_folder_name
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    download_from_original_stable_diffusion_ckpt, download_controlnet_from_original_ckpt
+)
+from transformers import CONFIG_MAPPING
+
+
+COMMIT_MESSAGE = " This PR adds fp32 and fp16 weights in PyTorch and safetensors format to {}"
+
+
+def convert_single(model_id: str, token:str, filename: str, model_type: str, sample_size: int, scheduler_type: str, extract_ema: bool, folder: str, progress):
+    from_safetensors = filename.endswith(".safetensors")
+
+    progress(0, desc="Downloading model")
+    local_file = os.path.join(model_id, filename)
+    ckpt_file = local_file if os.path.isfile(local_file) else hf_hub_download(repo_id=model_id, filename=filename, token=token)
+
+    if model_type == "v1":
+        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+    elif model_type == "v2":
+        if sample_size == 512:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference.yaml"
+        else:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+    elif model_type == "ControlNet":
+        config_url = (Path(model_id)/"resolve/main"/filename).with_suffix(".yaml")
+        config_url = "https://huggingface.co/" + str(config_url)
+
+    #config_file = BytesIO(requests.get(config_url).content)
+
+    response = requests.get(config_url)
+    with tempfile.NamedTemporaryFile(delete=False, mode='wb') as tmp_file:
+        tmp_file.write(response.content)
+        temp_config_file_path = tmp_file.name
+
+    if model_type == "ControlNet":
+        progress(0.2, desc="Converting ControlNet Model")
+        pipeline = download_controlnet_from_original_ckpt(ckpt_file, temp_config_file_path, image_size=sample_size, from_safetensors=from_safetensors, extract_ema=extract_ema)
+        to_args = {"dtype": torch.float16}
+    else:
+        progress(0.1, desc="Converting Model")
+        pipeline = download_from_original_stable_diffusion_ckpt(ckpt_file, temp_config_file_path, image_size=sample_size, scheduler_type=scheduler_type, from_safetensors=from_safetensors, extract_ema=extract_ema)
+        to_args = {"torch_dtype": torch.float16}
+
+    pipeline.save_pretrained(folder)
+    pipeline.save_pretrained(folder, safe_serialization=True)
+
+    #pipeline = pipeline.to(**to_args)
+    from diffusers import StableDiffusionPipeline
+    pipeline = StableDiffusionPipeline.from_pretrained(folder, use_safetensors=True, torch_dtype=torch.float16)
+    pipeline.save_pretrained(folder, variant="fp16")
+    pipeline.save_pretrained(folder, safe_serialization=True, variant="fp16")
+
+    return folder
+
+
+def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
+    try:
+        discussions = api.get_repo_discussions(repo_id=model_id)
+    except Exception:
+        return None
+    for discussion in discussions:
+        if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
+            details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
+            if details.target_branch == "refs/heads/main":
+                return discussion
+
+
+def convert(token: str, model_id: str, filename: str, model_type: str, sample_size: int = 512, scheduler_type: str = "pndm", extract_ema: bool = True, progress=gr.Progress()):
+    api = HfApi()
+
+    pr_title = "Adding `diffusers` weights of this model"
+
+    with TemporaryDirectory() as d:
+        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
+        os.makedirs(folder)
+        new_pr = None
+        try:
+            folder = convert_single(model_id, token, filename, model_type, sample_size, scheduler_type, extract_ema, folder, progress)
+            progress(0.7, desc="Uploading to Hub")
+            new_pr = api.upload_folder(folder_path=folder, path_in_repo="./", repo_id=model_id, repo_type="model", token=token, commit_message=pr_title, commit_description=COMMIT_MESSAGE.format(model_id), create_pr=True)
+            pr_number = new_pr.split("%2F")[-1].split("/")[0]
+            link = f"Pr created at: {'https://huggingface.co/' + os.path.join(model_id, 'discussions', pr_number)}"
+            progress(1, desc="Done")
+        except Exception as e:
+            raise gr.exceptions.Error(str(e))
+        finally:
+            shutil.rmtree(folder)
+
+    return link
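
For reference, once a PR produced by this script is merged, the fp16 weights it adds can be loaded straight from the Hub via the `variant` argument (a sketch; `some-user/some-model` is a placeholder repo id):

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder repo id; substitute the repository the PR was opened against.
pipeline = StableDiffusionPipeline.from_pretrained(
    "some-user/some-model",
    variant="fp16",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
```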