File size: 6,191 Bytes
c3d14af
 
f901652
b024062
 
941ac0f
bb1c525
207d269
bb1c525
9ef80c7
e11bea1
bb1c525
 
 
 
e11bea1
 
 
bb1c525
941ac0f
2b72305
 
9ef80c7
 
2b72305
9ef80c7
941ac0f
9ef80c7
4fe0601
9ef80c7
4fe0601
9ef80c7
 
4fe0601
9ef80c7
4fe0601
 
f901652
 
c3d14af
 
 
 
 
 
 
 
 
 
aed67d7
c3d14af
 
 
 
 
aed67d7
c3d14af
 
 
 
 
 
 
 
 
4fe0601
 
 
 
 
 
 
 
 
c3d14af
 
 
9c02482
b024062
 
bb1c525
c3d14af
bb1c525
207d269
 
 
 
 
aed67d7
bb1c525
941ac0f
d39504a
 
941ac0f
b024062
aed67d7
941ac0f
 
4fe0601
2b72305
 
 
941ac0f
 
4fe0601
941ac0f
4fe0601
941ac0f
 
 
b024062
aed67d7
 
 
 
4fe0601
 
 
 
 
207d269
aed67d7
207d269
b024062
 
207d269
 
 
 
 
 
 
 
 
 
 
b024062
 
 
 
 
c3d14af
 
 
bb1c525
 
 
 
c3d14af
 
207d269
 
 
 
 
bb1c525
aed67d7
c3d14af
aed67d7
207d269
9ef80c7
207d269
 
 
9ef80c7
207d269
4fe0601
 
 
bb1c525
aed67d7
 
 
bb1c525
207d269
 
 
 
c3d14af
207d269
 
 
 
 
 
aed67d7
207d269
bb1c525
e11bea1
207d269
 
 
 
 
e11bea1
cf49511
c3d14af
cf49511
bb1c525
 
941ac0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import dataclasses

import gradio as gr
import spaces
import torch
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid

# Hugging Face Hub model ids selectable in the UI.
DIFFUSERS_MODEL_IDS = [
    # SD Models
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/stable-diffusion-2-1",
    "runwayml/stable-diffusion-v1-5",

    # Other Models
    "Prgckwb/trpfrog-diffusion",
]

# Display name -> local checkpoint directory for models not pulled from the Hub.
EXTERNAL_MODEL_MAPPING = {
    "Beautiful Realistic Asians": "checkpoints/diffusers/Beautiful Realistic Asians v7",
}

MODEL_CHOICES = DIFFUSERS_MODEL_IDS + list(EXTERNAL_MODEL_MAPPING.keys())

# Model currently held by the global `pipe`; `inference` swaps it on demand.
current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 halves GPU memory use; CPU kernels generally require fp32.
dtype = torch.float16 if device == "cuda" else torch.float32

# Eagerly load the default pipeline so the first request is fast.
# BUG FIX: the original created `pipe` only on the CUDA path, so running on
# CPU later crashed with NameError when `inference` touched the global.
pipe = DiffusionPipeline.from_pretrained(
    current_model_id,
    torch_dtype=dtype,
).to(device)

@dataclasses.dataclass
class Input:
    """One row of generation settings for the Gradio demo.

    Field declaration order mirrors the order of the input widgets wired
    to the inference callback, so the values can be passed positionally.
    """
    prompt: str
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers"
    negative_prompt: str = ''
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 7.5
    num_inference_step: int = 28
    num_images: int = 4
    safety_checker: bool = True

    def to_list(self) -> list:
        """Return all field values, in declaration order, as a list."""
        return [getattr(self, field.name) for field in dataclasses.fields(self)]


# Preset example rows for `gr.Examples`; each entry is a positional value
# list matching the inference callback's inputs (see Input.to_list).
EXAMPLES = [
    Input(prompt='A cat holding a sign that says Hello world').to_list(),
    Input(
        prompt='Beautiful pixel art of a Wizard with hovering text "Achivement unlocked: Diffusion models can spell now"'
    ).to_list(),
    Input(prompt='A corgi wearing sunglasses says "U-Net is OVER!!"').to_list(),
    # Example for the external (local checkpoint) model, with settings
    # tuned for an SD1.5-based model: 512px, heavier negative prompt.
    Input(
        prompt='Cinematic Photo of a beautiful korean fashion model bokeh train',
        model_id='Beautiful Realistic Asians',
        negative_prompt='(worst_quality:2.0) (MajicNegative_V2:0.8) BadNegAnatomyV1-neg bradhands cartoon, cgi, render, illustration, painting, drawing',
        width=512,
        height=512,
        guidance_scale=5.0,
        num_inference_step=50,
    ).to_list()
]


@spaces.GPU(duration=120)
@torch.inference_mode()
def inference(
        prompt: str,
        model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 50,
        num_images: int = 4,
        safety_checker: bool = True,
        progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Generate `num_images` images for `prompt` and return them as one grid.

    Swaps the global pipeline when `model_id` differs from the currently
    loaded model (external display names are resolved via
    EXTERNAL_MODEL_MAPPING). Raises `gr.Error` if the pipeline fails to load.
    """
    progress(0, "Starting inference...")

    global current_model_id, pipe

    # gr.Number delivers floats by default; the pipeline and
    # make_image_grid require ints.
    width, height = int(width), int(height)
    num_inference_steps, num_images = int(num_inference_steps), int(num_images)

    progress(0.1, 'Loading pipeline...')
    if model_id != current_model_id:
        try:
            # For NOT Diffusers' Models: map display name -> local checkpoint.
            if model_id not in DIFFUSERS_MODEL_IDS:
                model_id = EXTERNAL_MODEL_MAPPING[model_id]

            pipe = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
            ).to(device)

            current_model_id = model_id
        except Exception as e:
            raise gr.Error(str(e))

    if not safety_checker:
        pipe.safety_checker = None

    progress(0.3, 'Loading Textual Inversion...')
    # Load Textual Inversion (best-effort).
    # BUG FIX: loading the same token twice raises ValueError, and some
    # pipelines (e.g. SD3) don't support textual inversion at all — either
    # way the whole request used to crash. Skip it and continue instead.
    try:
        pipe.load_textual_inversion(
            "checkpoints/embeddings/BadNegAnatomyV1 neg.pt", token='BadNegAnatomyV1-neg'
        )
    except (ValueError, AttributeError) as e:
        print(f"Skipping textual inversion: {e}")

    # Generation
    progress(0.4, 'Generating images...')
    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
    ).images

    # Arrange results in a grid: a single column for odd counts,
    # two rows otherwise.
    if num_images % 2 == 1:
        image = make_image_grid(images, rows=num_images, cols=1)
    else:
        image = make_image_grid(images, rows=2, cols=num_images // 2)

    return image


if __name__ == "__main__":
    theme = gr.themes.Default(primary_hue=gr.themes.colors.emerald)

    with gr.Blocks(theme=theme) as demo:
        # FIX: dropped the needless f-prefix (no placeholders in the literal).
        gr.Markdown("# Stable Diffusion Demo")

        with gr.Row():
            with gr.Column():
                prompt = gr.Text(label="Prompt", placeholder="Enter a prompt here")

                model_id = gr.Dropdown(
                    label="Model ID",
                    choices=MODEL_CHOICES,
                    value="stabilityai/stable-diffusion-3-medium-diffusers",
                )

                # Additional Input Settings
                with gr.Accordion("Additional Settings", open=False):
                    negative_prompt = gr.Text(label="Negative Prompt", value="", )

                    with gr.Row():
                        width = gr.Number(label="Width", value=512, step=64, minimum=64, maximum=2048)
                        height = gr.Number(label="Height", value=512, step=64, minimum=64, maximum=2048)
                        num_images = gr.Number(label="Num Images", value=4, minimum=1, maximum=10, step=1)

                    guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.5, minimum=0, maximum=10)
                    # BUG FIX: step was 2 with minimum=1, so only odd values
                    # were selectable and the default 50 (and the examples'
                    # 28/50) could never be re-selected on the slider.
                    num_inference_step = gr.Slider(
                        label="Num Inference Steps", value=50, minimum=1, maximum=100, step=1
                    )

                    with gr.Row():
                        safety_checker = gr.Checkbox(value=True, label='Use Safety Checker')

            with gr.Column():
                output_image = gr.Image(label="Image", type="pil")

        # Order must match both Input.to_list() and inference's signature.
        inputs = [
            prompt,
            model_id,
            negative_prompt,
            width,
            height,
            guidance_scale,
            num_inference_step,
            num_images,
            safety_checker
        ]

        btn = gr.Button("Generate")
        btn.click(
            fn=inference,
            inputs=inputs,
            outputs=output_image
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=inputs,
        )

    demo.queue().launch()