File size: 6,276 Bytes
cbc2ae6
 
 
 
 
 
 
e5f2ff8
cbc2ae6
 
 
 
 
 
 
 
 
 
 
 
 
b0178d6
cbc2ae6
 
 
 
b0178d6
 
cbc2ae6
 
b0178d6
cbc2ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71cefe5
cbc2ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
767c83c
cbc2ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0178d6
cbc2ae6
 
 
 
 
 
b0178d6
 
cbc2ae6
 
 
 
 
8a3ba61
 
71cefe5
 
 
cbc2ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import gradio
import subprocess
from PIL import Image
import torch, torch.backends.cudnn, torch.backends.cuda
from min_dalle import MinDalle
from emoji import demojize
import string

def filename_from_text(text: str) -> str:
    """Turn a free-form prompt into a safe, hyphen-separated file stem.

    Emoji are replaced by their textual names via ``demojize`` (with empty
    delimiters). The text is then lower-cased and stripped of non-ASCII
    characters, and everything except a-z and spaces is dropped. The result
    is cut to 64 characters and each whitespace run becomes a single hyphen.
    An empty result falls back to ``'blank'``.
    """
    spelled_out = demojize(text, delimiters=['', ''])
    # Lower-case before the ASCII round-trip so non-ASCII capitals that
    # lower to ASCII-plus-combining-marks keep their ASCII part.
    ascii_only = spelled_out.lower().encode('ascii', errors='ignore').decode()
    keep = set(string.ascii_lowercase + ' ')
    letters = ''.join(ch for ch in ascii_only if ch in keep)[:64]
    # str.split() with no argument already discards leading/trailing blanks.
    stem = '-'.join(letters.split())
    return stem if stem else 'blank'

def log_gpu_memory():
    """Print the output of ``nvidia-smi`` as a quick view of GPU memory use.

    This is purely diagnostic. The model in this file is created with
    ``device='cpu'``, so ``nvidia-smi`` may be missing or may exit with an
    error. In that case a short notice is printed instead of raising. A
    caller such as ``run_model``, which logs after an image has already
    been generated, is therefore not aborted by a missing GPU tool.
    """
    try:
        # argv list instead of a bare string: no shell and no platform-specific
        # string parsing of the command.
        output = subprocess.check_output(['nvidia-smi'])
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers FileNotFoundError/PermissionError when the binary is
        # absent; CalledProcessError covers a non-zero exit (e.g. no driver).
        print('nvidia-smi unavailable: {}'.format(err))
        return
    print(output.decode('utf-8', errors='replace'))

# A single MinDalle instance, built once at import time and shared by every
# call to run_model. It selects the "mega" variant with is_reusable=True and
# places the model on the CPU.
# NOTE(review): dtype is left at MinDalle's default; an explicit
# dtype=torch.float32 override was previously left commented out here.
model = MinDalle(device='cpu', is_mega=True, is_reusable=True)

def _parse_int_in_range(value, low: int, high: int, message: str) -> int:
    """Convert *value* to ``int`` and require ``low <= value <= high``.

    Raises:
        ValueError: carrying *message* if the conversion fails or the value
            falls outside the inclusive range.
    """
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(message) from None
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not low <= number <= high:
        raise ValueError(message)
    return number


def _parse_temperature(value) -> float:
    """Convert *value* to a ``float`` strictly greater than 1e-6.

    Raises:
        ValueError: if the conversion fails or the value is not positive.
    """
    message = 'Temperature must be a positive nonzero number'
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(message) from None
    # Written as `not >` so NaN (for which every comparison is False) is rejected.
    if not number > 1e-6:
        raise ValueError(message)
    return number


def _parse_supercondition(value) -> float:
    """Convert the super-condition setting to ``float``.

    Raises:
        ValueError: if *value* cannot be converted to a number.
    """
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError('Super condition must be a number') from None


def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image grid for *text* with the shared ``model`` and save it.

    Args:
        text: Prompt passed to ``model.generate_image``.
        grid_size: Grid side length; must convert to an int in 1..5.
        is_seamless: Forwarded (as ``bool``) to ``generate_image``.
        save_as_png: Save as ``.png`` when truthy, otherwise ``.jpg``.
        temperature: Must convert to a float greater than 1e-6.
        supercondition: Must convert to a float; passed as
            ``supercondition_factor``.
        top_k: Must convert to an int in 1..16384.

    Returns:
        Path of the saved image, relative to the current working directory.
        The file is named from ``filename_from_text(text)``.

    Raises:
        ValueError: if any setting is invalid. This is a subclass of
            ``Exception``, so handlers written for the previous generic
            ``Exception`` still catch it.
    """
    # Global torch settings applied on every call. Gradients are never needed
    # for inference. The cuDNN/CUDA flags presumably only matter when the model
    # runs on a GPU (it is built with device='cpu' here) — TODO confirm.
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    # Validate everything before touching the model so bad input fails fast.
    temperature = _parse_temperature(temperature)
    grid_size = _parse_int_in_range(
        grid_size, 1, 5, 'Grid size must be between 1 and 5'
    )
    top_k = _parse_int_in_range(
        top_k, 1, 16384, 'Top k must be between 1 and 16384'
    )
    supercondition_factor = _parse_supercondition(supercondition)

    with torch.no_grad():
        image = model.generate_image(
            text=text,
            seed=-1,  # presumably "pick a random seed" — confirm in MinDalle
            grid_size=grid_size,
            is_seamless=bool(is_seamless),
            temperature=temperature,
            supercondition_factor=supercondition_factor,
            top_k=top_k,
            is_verbose=True
        )

    log_gpu_memory()

    ext = 'png' if save_as_png else 'jpg'
    image_path = '{}.{}'.format(filename_from_text(text), ext)
    image.save(image_path)
    return image_path

# Gradio UI: prompt, button and result image on the left; generation
# settings on the right; clickable examples below; the button calls run_model.
# NOTE(review): analytics_enabled=True opts in to Gradio usage analytics —
# confirm that is intended.
demo = gradio.Blocks(analytics_enabled=True)

# Component creation order inside these `with` contexts determines the layout.
with demo:
    with gradio.Row():
        # Left column: prompt input, generate button, and output image.
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text', 
                value='Rusty Iron Man suit found abandoned in the woods being reclaimed by nature',
                lines=3
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            # Display-only; run_model returns a file path that fills this component.
            output_image = gradio.Image(
                # NOTE(review): the initial image is examples/dog.jpg, while the
                # examples below reference examples/dog.png — confirm both exist.
                value='examples/dog.jpg',
                label='Output Image',
                type='file',
                interactive=False
            )

        # Right column: basic and advanced generation settings.
        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                # Bounds match run_model's 1..5 grid-size validation.
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=3,
                    minimum=1, 
                    maximum=5,
                    step=1
                )
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                # Choices are strings "1".."16384" (powers of two, 2**0..2**14);
                # run_model converts them with int() and enforces 1..16384.
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                # Choices are strings "4".."64" (2**2..2**6); run_model
                # converts them with float().
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )

            # Static help text shown under the settings.
            gradio.Markdown(
                """
                #### Parameter
                - **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
                - **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
                - **Seamless**: Tile images in image token space instead of pixel space.
                - **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
                - **Top-k**: Each image token is sampled from the top-k scoring tokens.
                - **Super Condition**: Higher values can result in better agreement with the text.

                #### 
                """
            )

    # Each example row supplies (prompt, grid size, output image). No fn= is
    # passed, so presumably selecting an example only fills these components
    # (including a pre-rendered output image) — confirm for the Gradio version.
    gradio.Examples(
        examples=[
            ['A white cat with golden sunglasses on, pink background, studio lighting, 4k, award winning photography', 2, 'examples/cat.png'],
            ['an astronaut dancing on the moon’s surface, close-up photo', 2, 'examples/astronaut.png'],
            ['A photo of a Samoyed dog with its tongue out hugging a white Siamese cat', 5, 'examples/dog.png'],
            ['Dragons of Earth, Wind, Fire, powering up a huge sphere of compressed energy, digital art', 2, 'examples/dragon.png'],
            ['A snowboarder jumping in the air while coming down a ski mountain, concept art, artstation, unreal engine, 3d render, HD, Bokeh', 3, 'examples/snow.png'],
        ],
        inputs=[
            input_text,
            grid_size,
            output_image
        ],
        examples_per_page=20
    )

    # Component values are passed positionally, in the same order as
    # run_model's parameters; the returned path populates output_image.
    run_button.click(
        fn=run_model, 
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ], 
        outputs=[
            output_image
        ]
    )


# Starts the Gradio server. There is no `if __name__ == '__main__'` guard, so
# this runs whenever the module is executed or imported.
demo.launch()