Sergidev committed
Commit 7117c2e
1 Parent(s): 488e83c

Update app.py

Files changed (1)
  1. app.py +226 -62
app.py CHANGED
@@ -27,7 +27,6 @@ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
 OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
-THUMBNAIL_SIZE = (128, 128)  # Size for thumbnails
 
 MODEL = os.getenv(
     "MODEL",
@@ -39,11 +38,33 @@ torch.backends.cudnn.benchmark = False
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Store the generation history
-generation_history = []
+# Add a new global variable to store the image history
+image_history = []
 
 def load_pipeline(model_name):
-    # ... (rest of the function remains the same)
+    vae = AutoencoderKL.from_pretrained(
+        "madebyollin/sdxl-vae-fp16-fix",
+        torch_dtype=torch.float16,
+    )
+    pipeline = (
+        StableDiffusionXLPipeline.from_single_file
+        if MODEL.endswith(".safetensors")
+        else StableDiffusionXLPipeline.from_pretrained
+    )
+
+    pipe = pipeline(
+        model_name,
+        vae=vae,
+        torch_dtype=torch.float16,
+        custom_pipeline="lpw_stable_diffusion_xl",
+        use_safetensors=True,
+        add_watermarker=False,
+        use_auth_token=HF_TOKEN,
+        variant="fp16",
+    )
+
+    pipe.to(device)
+    return pipe
 
 @spaces.GPU
 def generate(
@@ -61,29 +82,95 @@ def generate(
     upscale_by: float = 1.5,
     progress=gr.Progress(track_tqdm=True),
 ) -> Image:
-    # ... (rest of the function remains the same)
+    generator = utils.seed_everything(seed)
+
+    width, height = utils.aspect_ratio_handler(
+        aspect_ratio_selector,
+        custom_width,
+        custom_height,
+    )
+
+    width, height = utils.preprocess_image_dimensions(width, height)
+
+    backup_scheduler = pipe.scheduler
+    pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
+
+    if use_upscaler:
+        upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
+    metadata = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "resolution": f"{width} x {height}",
+        "guidance_scale": guidance_scale,
+        "num_inference_steps": num_inference_steps,
+        "seed": seed,
+        "sampler": sampler,
+    }
+
+    if use_upscaler:
+        new_width = int(width * upscale_by)
+        new_height = int(height * upscale_by)
+        metadata["use_upscaler"] = {
+            "upscale_method": "nearest-exact",
+            "upscaler_strength": upscaler_strength,
+            "upscale_by": upscale_by,
+            "new_resolution": f"{new_width} x {new_height}",
+        }
+    else:
+        metadata["use_upscaler"] = None
+    logger.info(json.dumps(metadata, indent=4))
 
     try:
-        # ... (existing code for image generation)
+        if use_upscaler:
+            latents = pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                width=width,
+                height=height,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                generator=generator,
+                output_type="latent",
+            ).images
+            upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
+            images = upscaler_pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                image=upscaled_latents,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                strength=upscaler_strength,
+                generator=generator,
+                output_type="pil",
+            ).images
+        else:
+            images = pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                width=width,
+                height=height,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                generator=generator,
+                output_type="pil",
+            ).images
 
-        if images:
-            # Create thumbnail
-            thumbnail = images[0].copy()
-            thumbnail.thumbnail(THUMBNAIL_SIZE)
-
-            # Add to generation history
-            generation_history.append({
+        if images and IS_COLAB:
+            for image in images:
+                filepath = utils.save_image(image, metadata, OUTPUT_DIR)
+                logger.info(f"Image saved as {filepath} with metadata")
+
+        # Add the generated image and metadata to the history
+        for image in images:
+            thumbnail = image.copy()
+            thumbnail.thumbnail((256, 256))
+            image_history.insert(0, {
+                "image": thumbnail,
                 "prompt": prompt,
-                "thumbnail": thumbnail,
                 "metadata": metadata
             })
 
-        if IS_COLAB:
-            for image in images:
-                filepath = utils.save_image(image, metadata, OUTPUT_DIR)
-                logger.info(f"Image saved as {filepath} with metadata")
-
-        return images, metadata, update_history()
+        return images, metadata, gr.update(value=image_history)
     except Exception as e:
         logger.exception(f"An error occurred: {e}")
         raise
@@ -93,19 +180,6 @@ def generate(
         pipe.scheduler = backup_scheduler
         utils.free_memory()
 
-def update_history():
-    history_html = "<div style='display: flex; flex-wrap: wrap;'>"
-    for item in reversed(generation_history[-10:]):  # Show last 10 entries
-        thumbnail_path = f"data:image/png;base64,{utils.image_to_base64(item['thumbnail'])}"
-        history_html += f"""
-        <div style='margin: 5px; text-align: center;'>
-            <img src='{thumbnail_path}' style='width: 100px; height: 100px; object-fit: cover;'>
-            <p style='font-size: 12px; margin: 5px 0;'>{item['prompt'][:50]}...</p>
-        </div>
-        """
-    history_html += "</div>"
-    return history_html
-
 if torch.cuda.is_available():
     pipe = load_pipeline(MODEL)
     logger.info("Loaded on Device!")
@@ -128,43 +202,133 @@ with gr.Blocks(css="style.css") as demo:
     )
     with gr.Group():
         with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=5,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-            run_button = gr.Button(
-                "Generate",
-                variant="primary",
-                scale=0
-            )
-            result = gr.Gallery(
-                label="Result",
-                columns=1,
-                preview=True,
-                show_label=False
-            )
-
-            # Add the history display
-            history_display = gr.HTML(label="Generation History")
+            with gr.Column(scale=2):
+                prompt = gr.Text(
+                    label="Prompt",
+                    show_label=False,
+                    max_lines=5,
+                    placeholder="Enter your prompt",
+                    container=False,
+                )
+                run_button = gr.Button(
+                    "Generate",
+                    variant="primary",
+                    scale=0
+                )
+                result = gr.Gallery(
+                    label="Result",
+                    columns=1,
+                    preview=True,
+                    show_label=False
+                )
+
+            with gr.Column(scale=1):
+                history = gr.Gallery(
+                    label="Generation History",
+                    show_label=True,
+                    elem_id="history",
+                    columns=2,
+                    height=800,
+                )
 
     with gr.Accordion(label="Advanced Settings", open=False):
-        # ... (rest of the UI components remain the same)
+        negative_prompt = gr.Text(
+            label="Negative Prompt",
+            max_lines=5,
+            placeholder="Enter a negative prompt",
+            value=""
+        )
+        aspect_ratio_selector = gr.Radio(
+            label="Aspect Ratio",
+            choices=config.aspect_ratios,
+            value="1024 x 1024",
+            container=True,
+        )
+        with gr.Group(visible=False) as custom_resolution:
+            with gr.Row():
+                custom_width = gr.Slider(
+                    label="Width",
+                    minimum=MIN_IMAGE_SIZE,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=8,
+                    value=1024,
+                )
+                custom_height = gr.Slider(
+                    label="Height",
+                    minimum=MIN_IMAGE_SIZE,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=8,
+                    value=1024,
+                )
+        use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
+        with gr.Row() as upscaler_row:
+            upscaler_strength = gr.Slider(
+                label="Strength",
+                minimum=0,
+                maximum=1,
+                step=0.05,
+                value=0.55,
+                visible=False,
+            )
+            upscale_by = gr.Slider(
+                label="Upscale by",
+                minimum=1,
+                maximum=1.5,
+                step=0.1,
+                value=1.5,
+                visible=False,
+            )
 
+        sampler = gr.Dropdown(
+            label="Sampler",
+            choices=config.sampler_list,
+            interactive=True,
+            value="DPM++ 2M SDE Karras",
+        )
+        with gr.Row():
+            seed = gr.Slider(
+                label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
+            )
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+        with gr.Group():
+            with gr.Row():
+                guidance_scale = gr.Slider(
+                    label="Guidance scale",
+                    minimum=1,
+                    maximum=12,
+                    step=0.1,
+                    value=7.0,
+                )
+                num_inference_steps = gr.Slider(
+                    label="Number of inference steps",
+                    minimum=1,
+                    maximum=50,
+                    step=1,
+                    value=28,
+                )
     with gr.Accordion(label="Generation Parameters", open=False):
         gr_metadata = gr.JSON(label="Metadata", show_label=False)
-
     gr.Examples(
         examples=config.examples,
         inputs=prompt,
-        outputs=[result, gr_metadata, history_display],
+        outputs=[result, gr_metadata, history],
         fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
         cache_examples=CACHE_EXAMPLES,
     )
-
-    # ... (rest of the event handlers remain the same)
+    use_upscaler.change(
+        fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
+        inputs=use_upscaler,
+        outputs=[upscaler_strength, upscale_by],
+        queue=False,
+        api_name=False,
+    )
+    aspect_ratio_selector.change(
+        fn=lambda x: gr.update(visible=x == "Custom"),
+        inputs=aspect_ratio_selector,
+        outputs=custom_resolution,
+        queue=False,
+        api_name=False,
+    )
 
     inputs = [
         prompt,
@@ -190,7 +354,7 @@ with gr.Blocks(css="style.css") as demo:
     ).then(
         fn=generate,
        inputs=inputs,
-        outputs=[result, gr_metadata, history_display],
+        outputs=[result, gr_metadata, history],
         api_name="run",
     )
     negative_prompt.submit(
@@ -202,7 +366,7 @@ with gr.Blocks(css="style.css") as demo:
     ).then(
         fn=generate,
         inputs=inputs,
-        outputs=[result, gr_metadata, history_display],
+        outputs=[result, gr_metadata, history],
         api_name=False,
     )
     run_button.click(
@@ -214,7 +378,7 @@ with gr.Blocks(css="style.css") as demo:
     ).then(
         fn=generate,
         inputs=inputs,
-        outputs=[result, gr_metadata, history_display],
+        outputs=[result, gr_metadata, history],
         api_name=False,
     )
 
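
Note on the upscaler path: the `use_upscaler` branch in `generate` is a two-pass "hires fix". The base pipeline emits latents (`output_type="latent"`), `utils.upscale` enlarges them, and an img2img pipeline re-denoises the result at the higher resolution; building that second pipeline as `StableDiffusionXLImg2ImgPipeline(**pipe.components)` reuses the already-loaded weights instead of loading the model twice. `utils.upscale` itself is not part of this diff; a minimal sketch of what such a helper plausibly looks like, assuming the latents are a standard `(batch, channels, height, width)` tensor:

# Hedged sketch of a latent upscaler matching the call
# utils.upscale(latents, "nearest-exact", upscale_by) in the diff above;
# the real helper lives in utils.py and is not shown here.
import torch
import torch.nn.functional as F

def upscale(latents: torch.Tensor, mode: str, scale_factor: float) -> torch.Tensor:
    # "nearest-exact" is a valid F.interpolate mode since PyTorch 1.11.
    return F.interpolate(latents, scale_factor=scale_factor, mode=mode)

Upscaling in latent space keeps the second pass cheap: SDXL latents are 8x smaller per side than the decoded image, and the img2img pass only partially re-noises them (`strength=0.55` by default here).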
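Note on the history panel: this commit replaces the old base64-HTML rendering with a second `gr.Gallery` that every handler receives as its third output. One caveat: `gr.Gallery` renders a list of images or `(image, caption)` tuples, while `image_history` stores dicts with `"image"`, `"prompt"`, and `"metadata"` keys, so the `gr.update(value=image_history)` call will most likely need the entries flattened first. A minimal, self-contained sketch of the pattern, where `fake_generate` is a hypothetical stand-in for the real `generate`:

# Minimal sketch of the history-gallery pattern, assuming a recent Gradio
# release and Pillow are installed.
import gradio as gr
from PIL import Image

image_history = []  # newest entries first, as in app.py

def fake_generate(prompt: str):
    image = Image.new("RGB", (512, 512), "gray")  # stand-in for the SDXL output
    thumbnail = image.copy()
    thumbnail.thumbnail((256, 256))  # in-place resize that keeps aspect ratio
    image_history.insert(0, {"image": thumbnail, "prompt": prompt})
    # Flatten to (image, caption) tuples, which gr.Gallery knows how to render.
    history_value = [(item["image"], item["prompt"]) for item in image_history]
    return [image], gr.update(value=history_value)

with gr.Blocks() as demo:
    prompt = gr.Text(label="Prompt")
    result = gr.Gallery(label="Result", columns=1)
    history = gr.Gallery(label="Generation History", columns=2)
    prompt.submit(fn=fake_generate, inputs=prompt, outputs=[result, history])

demo.launch()

Also worth noting: a module-level `image_history` is shared by every visitor of a public Space, so per-session history would need `gr.State` instead of a global list.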