File size: 3,571 Bytes
f901652
b024062
 
941ac0f
bb1c525
 
9ef80c7
bb1c525
 
 
 
 
941ac0f
9ef80c7
 
 
 
 
 
941ac0f
 
9ef80c7
 
 
 
 
 
 
f901652
 
31e829c
b024062
 
 
bb1c525
 
 
 
941ac0f
d39504a
 
941ac0f
b024062
941ac0f
 
 
 
 
 
 
 
 
b024062
 
 
 
 
 
 
 
 
 
bb1c525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ef80c7
31e829c
 
9ef80c7
bb1c525
31e829c
 
9ef80c7
 
bb1c525
 
 
 
 
 
 
 
 
0d05f6d
 
 
bb1c525
 
 
 
 
31e829c
bb1c525
941ac0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import spaces
import torch
from PIL import Image
from diffusers import DiffusionPipeline

# Hugging Face Hub repository ids that DiffusionPipeline.from_pretrained loads directly.
DIFFUSERS_MODEL_IDS = [
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/stable-diffusion-2-1",
    "runwayml/stable-diffusion-v1-5",
]

# UI display name -> download URL for checkpoints hosted outside the Hub.
EXTERNAL_MODEL_URL_MAPPING = {
    "Beautiful Realistic Asians": (
        "https://civitai.com/api/download/models/177164"
        "?type=Model&format=SafeTensor&size=full&fp=fp16"
    ),
}

# Everything offered in the model dropdown: Hub ids first, then external display names.
MODEL_CHOICES = [*DIFFUSERS_MODEL_IDS, *EXTERNAL_MODEL_URL_MAPPING]

# Global Variables
# Id of the model currently held in ``pipe``; ``inference`` updates both together.
current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # Load the default model at import time so requests for it skip the load step.
    pipe = DiffusionPipeline.from_pretrained(
        current_model_id,
        torch_dtype=torch.float16,
    ).to(device)
else:
    # Nothing is preloaded on CPU, but the name must still exist: ``inference``
    # declares ``global pipe`` and would otherwise fail with a NameError.
    pipe = None



def _load_pipeline(model_id: str) -> DiffusionPipeline:
    """Load the pipeline for *model_id* and move it to ``device``.

    Keys of ``EXTERNAL_MODEL_URL_MAPPING`` are UI display names, not Hub
    repository ids, so they are resolved to their download URL first.

    Raises:
        Exception: whatever the underlying diffusers loader raises.
    """
    # Half precision only on CUDA; float16 kernels are poorly supported on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    url = EXTERNAL_MODEL_URL_MAPPING.get(model_id)
    if url is None:
        loaded = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    else:
        # NOTE(review): external entries point at a single checkpoint file, not a
        # diffusers repo, so from_pretrained cannot load them. This assumes the
        # checkpoint is a Stable Diffusion 1.x-style model and that the installed
        # diffusers version accepts a non-Hub URL here — confirm both.
        from diffusers import StableDiffusionPipeline

        loaded = StableDiffusionPipeline.from_single_file(url, torch_dtype=dtype)
    return loaded.to(device)


@spaces.GPU()
@torch.inference_mode()
def inference(
        model_id: str,
        prompt: str,
        negative_prompt: str = "",
        progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Generate one image for *prompt* using the pipeline for *model_id*.

    The pipeline is cached in the module globals ``pipe`` / ``current_model_id``.
    It is reloaded only when a different model is requested or when none is
    loaded yet.

    Args:
        model_id: A Hub id from ``DIFFUSERS_MODEL_IDS`` or a display name that is
            a key of ``EXTERNAL_MODEL_URL_MAPPING``.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline; empty by default.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio follow
            tqdm progress bars.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: if loading the requested model fails. The previously loaded
            pipeline stays active in that case.
    """
    progress(0, "Starting inference...")

    global current_model_id, pipe

    # ``pipe`` is only created at import time on CUDA hosts, so look it up via
    # globals() to cover the "nothing loaded yet" case without a NameError.
    if model_id != current_model_id or globals().get("pipe") is None:
        try:
            new_pipe = _load_pipeline(model_id)
        except Exception as e:
            # Show loader failures in the UI rather than as a generic server error.
            raise gr.Error(str(e)) from e
        # Commit only after a successful load so a failure keeps the old pipeline.
        pipe = new_pipe
        current_model_id = model_id

    image = pipe(
        prompt,
        negative_prompt=negative_prompt,
    ).images[0]

    return image


if __name__ == "__main__":
    # Hub id used by the bundled example prompts.
    sd3_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"

    with gr.Blocks() as demo:
        gr.Markdown("# Stable Diffusion Demo")

        with gr.Row():
            with gr.Column():
                # Order must match inference(model_id, prompt, negative_prompt):
                # btn.click passes these component values positionally.
                inputs = [
                    gr.Dropdown(
                        label="Model ID",
                        choices=MODEL_CHOICES,
                        # Start on the initial ``current_model_id`` so the dropdown
                        # matches the pipeline the module set up at startup.
                        value=current_model_id,
                    ),
                    gr.Text(label="Prompt", value=""),
                    gr.Text(label="Negative Prompt", value=""),
                ]

                with gr.Accordion("Additional Settings (W.I.P)", open=False):
                    with gr.Row():
                        width_component = gr.Number(label="Width", value=512, step=64, minimum=64, maximum=1024)
                        height_component = gr.Number(label="Height", value=512, step=64, minimum=64, maximum=1024)

                    # NOTE: not wired to the Generate button yet. inference() does not
                    # accept these settings, so they currently have no effect.
                    additional_inputs = [
                        width_component,
                        height_component,
                        gr.Number(label="Guidance Scale", value=7.5, step=0.5, minimum=0, maximum=10),
                        gr.Slider(label="Num Inference Steps", value=None, minimum=1, maximum=1000, step=1)
                    ]

            with gr.Column():
                outputs = [
                    gr.Image(label="Image", type="pil"),
                ]

        # Each example row supplies exactly one value per component in ``inputs``
        # (model, prompt, negative prompt). Short rows can't fill every input.
        gr.Examples(
            examples=[
                [sd3_model_id, 'A cat holding a sign that says Hello world', ""],
                [sd3_model_id, 'Beautiful pixel art of a Wizard with hovering text "Achivement unlocked: Diffusion models can spell now"', ""],
                [sd3_model_id, 'A corgi wearing sunglasses says "U-Net is OVER!!"', ""],
            ],
            inputs=inputs
        )

        btn = gr.Button("Generate")
        btn.click(fn=inference, inputs=inputs, outputs=outputs)

    demo.queue().launch()