File size: 6,438 Bytes
c3d14af
 
f901652
b024062
 
941ac0f
bb1c525
207d269
bb1c525
9ef80c7
e11bea1
bb1c525
 
 
 
e11bea1
 
 
bb1c525
2b72305
 
9ef80c7
2b72305
9ef80c7
941ac0f
4fe0601
cc8ba3d
f901652
 
c3d14af
 
 
 
 
 
 
 
 
 
255b6b2
 
bf30e39
c3d14af
 
 
 
 
255b6b2
bf30e39
c3d14af
 
 
 
 
 
 
 
 
4fe0601
 
 
1f3da27
4fe0601
 
 
 
 
c3d14af
 
 
9c02482
b024062
 
bb1c525
c3d14af
bb1c525
207d269
 
 
 
 
aed67d7
cc8ba3d
59df77b
bb1c525
941ac0f
d39504a
 
941ac0f
b024062
aed67d7
1c38bbc
 
2b72305
1c38bbc
 
 
 
4fe0601
1c38bbc
b024062
aed67d7
5566482
aed67d7
dc8517c
 
 
 
 
 
4fe0601
207d269
aed67d7
cc8ba3d
 
 
 
 
59df77b
 
207d269
b024062
 
207d269
 
 
 
 
59df77b
207d269
 
 
 
 
 
b024062
 
 
 
 
c3d14af
 
 
bb1c525
 
 
 
c3d14af
 
207d269
 
 
 
 
bb1c525
aed67d7
c3d14af
aed67d7
207d269
9ef80c7
207d269
 
 
59df77b
9ef80c7
207d269
4fe0601
 
 
bb1c525
aed67d7
255b6b2
 
aed67d7
bb1c525
207d269
 
 
 
c3d14af
207d269
 
 
 
 
 
255b6b2
 
59df77b
207d269
bb1c525
e11bea1
207d269
 
 
 
 
e11bea1
cf49511
c3d14af
cf49511
bb1c525
 
941ac0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import dataclasses

import gradio as gr
import spaces
import torch
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid

# Hugging Face Hub model ids that DiffusionPipeline.from_pretrained() can
# load directly by name.
DIFFUSERS_MODEL_IDS = [
    # SD Models
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/stable-diffusion-2-1",
    "runwayml/stable-diffusion-v1-5",

    # Other Models
    "Prgckwb/trpfrog-diffusion",
]
# Display name -> local checkpoint directory for models that are NOT on the
# Hub; inference() resolves these before loading.
EXTERNAL_MODEL_MAPPING = {
    "Beautiful Realistic Asians": "checkpoints/diffusers/Beautiful Realistic Asians v7",
}
# Everything shown in the model dropdown: Hub ids plus external display names.
MODEL_CHOICES = DIFFUSERS_MODEL_IDS + list(EXTERNAL_MODEL_MAPPING.keys())

# Module-level state shared with inference() (declared `global` there):
# the id of the most recently loaded model and the loaded pipeline
# (None until the first request).
current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = None


@dataclasses.dataclass
class Input:
    """One generation request: every UI control value, with the UI defaults.

    Field declaration order matches the positional parameter order of
    ``inference`` and of the ``inputs`` list wired to the Generate button.
    """

    prompt: str
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers"
    negative_prompt: str = ''
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 7.5
    num_inference_step: int = 28
    num_images: int = 4
    use_safety_checker: bool = True
    use_model_offload: bool = False
    seed: int = 8888

    def to_list(self):
        """Return all field values as a flat list, in declaration order."""
        return [getattr(self, field.name) for field in dataclasses.fields(self)]


# Pre-filled rows for the gr.Examples widget. Each entry is a flat list whose
# order matches the `inputs` list wired to the Generate button (see Input).
EXAMPLES = [
    Input(prompt='A cat holding a sign that says Hello world').to_list(),
    Input(
        prompt='Beautiful pixel art of a Wizard with hovering text "Achivement unlocked: Diffusion models can spell now"'
    ).to_list(),
    Input(prompt='A corgi wearing sunglasses says "U-Net is OVER!!"').to_list(),
    # Example for the external (non-Hub) checkpoint; uses its textual
    # inversion token in the negative prompt.
    Input(
        prompt='Cinematic Photo of a beautiful korean fashion model bokeh train',
        model_id='Beautiful Realistic Asians',
        negative_prompt='worst_quality, BadNegAnatomyV1-neg, bradhands cartoon, cgi, render, illustration, painting, drawing',
        width=512,
        height=512,
        guidance_scale=5.0,
        num_inference_step=50,
    ).to_list()
]


@spaces.GPU(duration=120)
@torch.inference_mode()
def inference(
        prompt: str,
        model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 50,
        num_images: int = 4,
        safety_checker: bool = True,
        use_model_offload: bool = False,
        seed: int = 8888,
        progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Generate ``num_images`` images for ``prompt`` and return them as one grid.

    Args:
        prompt: Text prompt describing the desired image.
        model_id: A Hub id from DIFFUSERS_MODEL_IDS, or a display name key
            of EXTERNAL_MODEL_MAPPING (resolved to a local checkpoint path).
        negative_prompt: Concepts to steer the generation away from.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Denoising steps per image.
        num_images: Images generated per prompt.
        safety_checker: When False, the pipeline's safety checker is disabled.
        use_model_offload: Offload model parts to CPU to reduce VRAM usage.
        seed: RNG seed for reproducible results.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A single PIL image with all generated images arranged in a grid.
    """
    progress(0, "Starting inference...")

    global current_model_id, pipe

    # Resolve display names of non-Hub models to their local checkpoint path.
    # Note: the resolved path is also absent from DIFFUSERS_MODEL_IDS, so the
    # flag stays valid after reassignment (matches the original's two checks).
    is_external_model = model_id not in DIFFUSERS_MODEL_IDS
    if is_external_model:
        model_id = EXTERNAL_MODEL_MAPPING[model_id]

    # BUG FIX: `current_model_id` was assigned but never consulted, so the
    # multi-GB pipeline was rebuilt on every request. Reload only when the
    # requested model actually changed (or nothing is loaded yet).
    if pipe is None or model_id != current_model_id:
        progress(0.1, 'Loading pipeline...')
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        )
        current_model_id = model_id

        if is_external_model:
            progress(0.3, 'Loading Textual Inversion...')
            # Embedding referenced by the example negative prompts; loading it
            # once per pipeline also avoids a duplicate-token error on reuse.
            pipe.load_textual_inversion(
                "checkpoints/embeddings/BadNegAnatomyV1 neg.pt", token='BadNegAnatomyV1-neg'
            )

    # NOTE(review): disabling is destructive on a cached pipeline — the checker
    # comes back only after the pipeline is reloaded (i.e. on a model switch).
    if not safety_checker:
        pipe.safety_checker = None

    # Generation
    progress(0.4, 'Generating images...')
    if use_model_offload:
        pipe.enable_model_cpu_offload()
    else:
        # BUG FIX: was hard-coded to 'cuda', which crashes on CPU-only hosts;
        # use the module-level device selected at startup.
        pipe = pipe.to(device)

    generator = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        generator=generator,
    ).images

    # Lay out results: one column for odd counts, two rows for even counts.
    if num_images % 2 == 1:
        image = make_image_grid(images, rows=num_images, cols=1)
    else:
        image = make_image_grid(images, rows=2, cols=num_images // 2)

    return image


if __name__ == "__main__":
    # Build and launch the Gradio UI. The widget variables below are collected
    # into `inputs` in the exact positional order of inference()'s parameters.
    theme = gr.themes.Default(primary_hue=gr.themes.colors.emerald)

    with gr.Blocks(theme=theme) as demo:
        gr.Markdown(f"# Stable Diffusion Demo")

        with gr.Row():
            with gr.Column():
                prompt = gr.Text(label="Prompt", placeholder="Enter a prompt here")

                model_id = gr.Dropdown(
                    label="Model ID",
                    choices=MODEL_CHOICES,
                    value="stabilityai/stable-diffusion-3-medium-diffusers",
                )

                # Additional Input Settings
                with gr.Accordion("Additional Settings", open=False):
                    negative_prompt = gr.Text(label="Negative Prompt", value="", )

                    with gr.Row():
                        width = gr.Number(label="Width", value=512, step=64, minimum=64, maximum=2048)
                        height = gr.Number(label="Height", value=512, step=64, minimum=64, maximum=2048)
                        num_images = gr.Number(label="Num Images", value=4, minimum=1, maximum=10, step=1)
                        seed = gr.Number(label="Seed", value=8888,  step=1)

                    guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.5, minimum=0, maximum=10)
                    # NOTE(review): slider default (50) differs from the Input
                    # dataclass default (28) — confirm which is intended.
                    num_inference_step = gr.Slider(
                        label="Num Inference Steps", value=50, minimum=1, maximum=100, step=2
                    )

                    with gr.Row():
                        use_safety_checker = gr.Checkbox(value=True, label='Use Safety Checker')
                        use_model_offload = gr.Checkbox(value=False, label='Use Model Offload')

            with gr.Column():
                output_image = gr.Image(label="Image", type="pil")

        # Order must match inference()'s positional parameters and the rows
        # produced by Input.to_list() for the examples below.
        inputs = [
            prompt,
            model_id,
            negative_prompt,
            width,
            height,
            guidance_scale,
            num_inference_step,
            num_images,
            use_safety_checker,
            use_model_offload,
            seed,
        ]

        btn = gr.Button("Generate")
        btn.click(
            fn=inference,
            inputs=inputs,
            outputs=output_image
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=inputs,
        )

    # queue() enables request queuing so long-running GPU jobs don't time out.
    demo.queue().launch()