Sergidev commited on
Commit
bf6dd2e
1 Parent(s): 16189d5
Files changed (1) hide show
  1. app.py +78 -206
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import gc
3
- import random
4
  import gradio as gr
5
  import numpy as np
6
  import torch
@@ -14,82 +13,24 @@ from datetime import datetime
14
  from diffusers.models import AutoencoderKL
15
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
16
 
17
- logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger(__name__)
19
-
20
- DESCRIPTION = "PonyDiffusion V6 XL"
21
- if not torch.cuda.is_available():
22
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
23
- IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
24
- HF_TOKEN = os.getenv("HF_TOKEN")
25
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
26
- MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
27
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
28
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
29
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
30
- OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
31
-
32
- MODEL = os.getenv(
33
- "MODEL",
34
- "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
35
- )
36
-
37
- torch.backends.cudnn.deterministic = True
38
- torch.backends.cudnn.benchmark = False
39
-
40
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41
-
42
- def load_pipeline(model_name):
43
- vae = AutoencoderKL.from_pretrained(
44
- "madebyollin/sdxl-vae-fp16-fix",
45
- torch_dtype=torch.float16,
46
- )
47
- pipeline = (
48
- StableDiffusionXLPipeline.from_single_file
49
- if MODEL.endswith(".safetensors")
50
- else StableDiffusionXLPipeline.from_pretrained
51
- )
52
-
53
- pipe = pipeline(
54
- model_name,
55
- vae=vae,
56
- torch_dtype=torch.float16,
57
- custom_pipeline="lpw_stable_diffusion_xl",
58
- use_safetensors=True,
59
- add_watermarker=False,
60
- use_auth_token=HF_TOKEN,
61
- variant="fp16",
62
- )
63
-
64
- pipe.to(device)
65
- return pipe
66
 
 
67
  def parse_json_parameters(json_str):
68
  try:
69
  params = json.loads(json_str)
 
 
 
 
70
  return params
71
  except json.JSONDecodeError:
72
- return None
73
-
74
- def apply_json_parameters(json_str):
75
- params = parse_json_parameters(json_str)
76
- if params:
77
- return (
78
- params.get("prompt", ""),
79
- params.get("negative_prompt", ""),
80
- params.get("seed", 0),
81
- params.get("width", 1024),
82
- params.get("height", 1024),
83
- params.get("guidance_scale", 7.0),
84
- params.get("num_inference_steps", 30),
85
- params.get("sampler", "DPM++ 2M SDE Karras"),
86
- params.get("aspect_ratio", "1024 x 1024"),
87
- params.get("use_upscaler", False),
88
- params.get("upscaler_strength", 0.55),
89
- params.get("upscale_by", 1.5),
90
- )
91
- return [gr.update()] * 12
92
 
 
 
93
  def generate(
94
  prompt: str,
95
  negative_prompt: str = "",
@@ -103,8 +44,23 @@ def generate(
103
  use_upscaler: bool = False,
104
  upscaler_strength: float = 0.55,
105
  upscale_by: float = 1.5,
 
106
  progress=gr.Progress(track_tqdm=True),
107
  ) -> Image:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  generator = utils.seed_everything(seed)
109
 
110
  width, height = utils.aspect_ratio_handler(
@@ -183,7 +139,7 @@ def generate(
183
  filepath = utils.save_image(image, metadata, OUTPUT_DIR)
184
  logger.info(f"Image saved as {filepath} with metadata")
185
 
186
- return images, json.dumps(metadata) # Return metadata as a JSON string
187
  except Exception as e:
188
  logger.exception(f"An error occurred: {e}")
189
  raise
@@ -193,21 +149,30 @@ def generate(
193
  pipe.scheduler = backup_scheduler
194
  utils.free_memory()
195
 
196
- def get_random_prompt():
197
- anime_characters = [
198
- "Naruto Uzumaki", "Monkey D. Luffy", "Goku", "Eren Yeager", "Light Yagami",
199
- "Lelouch Lamperouge", "Edward Elric", "Levi Ackerman", "Spike Spiegel",
200
- "Sakura Haruno", "Mikasa Ackerman", "Asuka Langley Soryu", "Rem", "Megumin",
201
- "Violet Evergarden"
202
- ]
203
- styles = ["pixel art", "stylized anime", "digital art", "watercolor", "sketch"]
204
- scores = ["score_9", "score_8_up", "score_7_up"]
205
-
206
- character = random.choice(anime_characters)
207
- style = random.choice(styles)
208
- score = ", ".join(random.sample(scores, k=3))
209
-
210
- return f"{score}, {character}, {style}, show accurate"
 
 
 
 
 
 
 
 
 
211
 
212
  if torch.cuda.is_available():
213
  pipe = load_pipeline(MODEL)
@@ -215,59 +180,7 @@ if torch.cuda.is_available():
215
  else:
216
  pipe = None
217
 
218
- # Define the JavaScript code as a string
219
- js_code = """
220
- <script>
221
- document.addEventListener('DOMContentLoaded', (event) => {
222
- const historyDropdown = document.getElementById('history-dropdown');
223
- const resultGallery = document.querySelector('.gallery');
224
-
225
- if (historyDropdown && resultGallery) {
226
- const observer = new MutationObserver((mutations) => {
227
- mutations.forEach((mutation) => {
228
- if (mutation.type === 'childList' && mutation.addedNodes.length > 0) {
229
- const newImage = mutation.addedNodes[0];
230
- if (newImage.tagName === 'IMG') {
231
- updateHistory(newImage.src);
232
- }
233
- }
234
- });
235
- });
236
-
237
- observer.observe(resultGallery, { childList: true });
238
-
239
- function updateHistory(imageSrc) {
240
- const prompt = document.querySelector('#prompt textarea').value;
241
- const option = document.createElement('option');
242
- option.value = prompt;
243
- option.textContent = prompt;
244
- option.setAttribute('data-image', imageSrc);
245
-
246
- historyDropdown.insertBefore(option, historyDropdown.firstChild);
247
-
248
- if (historyDropdown.children.length > 10) {
249
- historyDropdown.removeChild(historyDropdown.lastChild);
250
- }
251
- }
252
-
253
- historyDropdown.addEventListener('change', (event) => {
254
- const selectedOption = event.target.selectedOptions[0];
255
- const imageSrc = selectedOption.getAttribute('data-image');
256
- if (imageSrc) {
257
- const img = document.createElement('img');
258
- img.src = imageSrc;
259
- resultGallery.innerHTML = '';
260
- resultGallery.appendChild(img);
261
- }
262
- });
263
- }
264
- });
265
- </script>
266
- """
267
-
268
  with gr.Blocks(css="style.css") as demo:
269
- gr.HTML(js_code) # Add the JavaScript code to the interface
270
-
271
  title = gr.HTML(
272
  f"""<h1><span>{DESCRIPTION}</span></h1>""",
273
  elem_id="title",
@@ -376,34 +289,24 @@ with gr.Blocks(css="style.css") as demo:
376
  step=1,
377
  value=28,
378
  )
379
- with gr.Accordion(label="JSON Parameters", open=False):
380
- json_input = gr.TextArea(label="Input JSON parameters")
381
- apply_json_button = gr.Button("Apply JSON Parameters")
382
-
383
- with gr.Row():
384
- clear_button = gr.Button("Clear All")
385
- random_prompt_button = gr.Button("Random Prompt")
386
-
387
- history = gr.State([]) # Add a state component to store history
388
- history_dropdown = gr.Dropdown(label="Generation History", choices=[], interactive=True, elem_id="history-dropdown")
389
-
390
  with gr.Accordion(label="Generation Parameters", open=False):
391
  gr_metadata = gr.JSON(label="Metadata", show_label=False)
392
-
393
- def update_history(images, metadata, history):
394
- if images:
395
- new_entry = {"prompt": json.loads(metadata)["prompt"], "image": images[0]}
396
- history = [new_entry] + history[:9] # Keep only the last 10 entries
397
- return gr.update(choices=[h["prompt"] for h in history]), history
 
398
 
399
  gr.Examples(
400
  examples=config.examples,
401
  inputs=prompt,
402
  outputs=[result, gr_metadata],
403
- fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
404
  cache_examples=CACHE_EXAMPLES,
405
  )
406
-
407
  use_upscaler.change(
408
  fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
409
  inputs=use_upscaler,
@@ -432,6 +335,7 @@ with gr.Blocks(css="style.css") as demo:
432
  use_upscaler,
433
  upscaler_strength,
434
  upscale_by,
 
435
  ]
436
 
437
  prompt.submit(
@@ -441,16 +345,11 @@ with gr.Blocks(css="style.css") as demo:
441
  queue=False,
442
  api_name=False,
443
  ).then(
444
- fn=generate,
445
  inputs=inputs,
446
- outputs=[result, gr_metadata],
447
  api_name="run",
448
- ).then(
449
- fn=update_history,
450
- inputs=[result, gr_metadata, history],
451
- outputs=[history_dropdown, history],
452
  )
453
-
454
  negative_prompt.submit(
455
  fn=utils.randomize_seed_fn,
456
  inputs=[seed, randomize_seed],
@@ -458,16 +357,11 @@ with gr.Blocks(css="style.css") as demo:
458
  queue=False,
459
  api_name=False,
460
  ).then(
461
- fn=generate,
462
  inputs=inputs,
463
- outputs=[result, gr_metadata],
464
  api_name=False,
465
- ).then(
466
- fn=update_history,
467
- inputs=[result, gr_metadata, history],
468
- outputs=[history_dropdown, history],
469
  )
470
-
471
  run_button.click(
472
  fn=utils.randomize_seed_fn,
473
  inputs=[seed, randomize_seed],
@@ -475,47 +369,25 @@ with gr.Blocks(css="style.css") as demo:
475
  queue=False,
476
  api_name=False,
477
  ).then(
478
- fn=generate,
479
  inputs=inputs,
480
- outputs=[result, gr_metadata],
481
  api_name=False,
482
- ).then(
483
- fn=update_history,
484
- inputs=[result, gr_metadata, history],
485
- outputs=[history_dropdown, history],
486
- )
487
-
488
- apply_json_button.click(
489
- fn=apply_json_parameters,
490
- inputs=json_input,
491
- outputs=[prompt, negative_prompt, seed, custom_width, custom_height,
492
- guidance_scale, num_inference_steps, sampler,
493
- aspect_ratio_selector, use_upscaler, upscaler_strength, upscale_by]
494
- )
495
-
496
- clear_button.click(
497
- fn=lambda: (gr.update(value=""), gr.update(value=""), gr.update(value=0),
498
- gr.update(value=1024), gr.update(value=1024),
499
- gr.update(value=7.0), gr.update(value=30),
500
- gr.update(value="DPM++ 2M SDE Karras"),
501
- gr.update(value="1024 x 1024"), gr.update(value=False),
502
- gr.update(value=0.55), gr.update(value=1.5)),
503
- inputs=[],
504
- outputs=[prompt, negative_prompt, seed, custom_width, custom_height,
505
- guidance_scale, num_inference_steps, sampler,
506
- aspect_ratio_selector, use_upscaler, upscaler_strength, upscale_by]
507
  )
508
 
509
- random_prompt_button.click(
510
- fn=get_random_prompt,
511
- inputs=[],
512
- outputs=prompt
 
 
513
  )
514
 
 
515
  history_dropdown.change(
516
- fn=lambda x, history: next((h["prompt"] for h in history if h["prompt"] == x), ""),
517
- inputs=[history_dropdown, history],
518
- outputs=prompt
519
  )
520
 
521
- demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
 
1
  import os
2
  import gc
 
3
  import gradio as gr
4
  import numpy as np
5
  import torch
 
13
  from diffusers.models import AutoencoderKL
14
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
15
 
16
+ # ... (keep the existing imports and configurations)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # Add a new function to parse and validate JSON input
19
  def parse_json_parameters(json_str):
20
  try:
21
  params = json.loads(json_str)
22
+ required_keys = ['prompt', 'negative_prompt', 'seed', 'width', 'height', 'guidance_scale', 'num_inference_steps', 'sampler']
23
+ for key in required_keys:
24
+ if key not in params:
25
+ raise ValueError(f"Missing required key: {key}")
26
  return params
27
  except json.JSONDecodeError:
28
+ raise ValueError("Invalid JSON format")
29
+ except Exception as e:
30
+ raise ValueError(f"Error parsing JSON: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ # Modify the generate function to accept JSON parameters
33
+ @spaces.GPU
34
  def generate(
35
  prompt: str,
36
  negative_prompt: str = "",
 
44
  use_upscaler: bool = False,
45
  upscaler_strength: float = 0.55,
46
  upscale_by: float = 1.5,
47
+ json_params: str = "",
48
  progress=gr.Progress(track_tqdm=True),
49
  ) -> Image:
50
+ if json_params:
51
+ try:
52
+ params = parse_json_parameters(json_params)
53
+ prompt = params['prompt']
54
+ negative_prompt = params['negative_prompt']
55
+ seed = params['seed']
56
+ custom_width = params['width']
57
+ custom_height = params['height']
58
+ guidance_scale = params['guidance_scale']
59
+ num_inference_steps = params['num_inference_steps']
60
+ sampler = params['sampler']
61
+ except ValueError as e:
62
+ raise gr.Error(str(e))
63
+
64
  generator = utils.seed_everything(seed)
65
 
66
  width, height = utils.aspect_ratio_handler(
 
139
  filepath = utils.save_image(image, metadata, OUTPUT_DIR)
140
  logger.info(f"Image saved as {filepath} with metadata")
141
 
142
+ return images, metadata
143
  except Exception as e:
144
  logger.exception(f"An error occurred: {e}")
145
  raise
 
149
  pipe.scheduler = backup_scheduler
150
  utils.free_memory()
151
 
152
+ # Initialize an empty list to store the generation history
153
+ generation_history = []
154
+
155
+ # Function to update the history dropdown
156
+ def update_history_dropdown():
157
+ return gr.Dropdown.update(choices=[f"{item['prompt']} ({item['timestamp']})" for item in generation_history])
158
+
159
+ # Modify the generate function to add results to the history
160
+ def generate_and_update_history(*args, **kwargs):
161
+ images, metadata = generate(*args, **kwargs)
162
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
163
+ generation_history.insert(0, {"prompt": metadata["prompt"], "timestamp": timestamp, "image": images[0], "metadata": metadata})
164
+ if len(generation_history) > 10: # Limit history to 10 items
165
+ generation_history.pop()
166
+ return images, metadata, update_history_dropdown()
167
+
168
+ # Function to display selected history item
169
+ def display_history_item(selected_item):
170
+ if not selected_item:
171
+ return None, None
172
+ for item in generation_history:
173
+ if f"{item['prompt']} ({item['timestamp']})" == selected_item:
174
+ return item['image'], json.dumps(item['metadata'], indent=2)
175
+ return None, None
176
 
177
  if torch.cuda.is_available():
178
  pipe = load_pipeline(MODEL)
 
180
  else:
181
  pipe = None
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  with gr.Blocks(css="style.css") as demo:
 
 
184
  title = gr.HTML(
185
  f"""<h1><span>{DESCRIPTION}</span></h1>""",
186
  elem_id="title",
 
289
  step=1,
290
  value=28,
291
  )
 
 
 
 
 
 
 
 
 
 
 
292
  with gr.Accordion(label="Generation Parameters", open=False):
293
  gr_metadata = gr.JSON(label="Metadata", show_label=False)
294
+ json_input = gr.TextArea(label="Edit/Paste JSON Parameters", placeholder="Paste or edit JSON parameters here")
295
+ generate_from_json = gr.Button("Generate from JSON")
296
+
297
+ # Add history dropdown
298
+ history_dropdown = gr.Dropdown(label="Generation History", choices=[], interactive=True)
299
+ history_image = gr.Image(label="Selected Image", interactive=False)
300
+ history_metadata = gr.JSON(label="Selected Metadata", show_label=False)
301
 
302
  gr.Examples(
303
  examples=config.examples,
304
  inputs=prompt,
305
  outputs=[result, gr_metadata],
306
+ fn=lambda *args, **kwargs: generate_and_update_history(*args, use_upscaler=True, **kwargs),
307
  cache_examples=CACHE_EXAMPLES,
308
  )
309
+
310
  use_upscaler.change(
311
  fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
312
  inputs=use_upscaler,
 
335
  use_upscaler,
336
  upscaler_strength,
337
  upscale_by,
338
+ json_input, # Add JSON input to the list of inputs
339
  ]
340
 
341
  prompt.submit(
 
345
  queue=False,
346
  api_name=False,
347
  ).then(
348
+ fn=generate_and_update_history, # Use the new function
349
  inputs=inputs,
350
+ outputs=[result, gr_metadata, history_dropdown], # Add history_dropdown to outputs
351
  api_name="run",
 
 
 
 
352
  )
 
353
  negative_prompt.submit(
354
  fn=utils.randomize_seed_fn,
355
  inputs=[seed, randomize_seed],
 
357
  queue=False,
358
  api_name=False,
359
  ).then(
360
+ fn=generate_and_update_history, # Use the new function
361
  inputs=inputs,
362
+ outputs=[result, gr_metadata, history_dropdown], # Add history_dropdown to outputs
363
  api_name=False,
 
 
 
 
364
  )
 
365
  run_button.click(
366
  fn=utils.randomize_seed_fn,
367
  inputs=[seed, randomize_seed],
 
369
  queue=False,
370
  api_name=False,
371
  ).then(
372
+ fn=generate_and_update_history, # Use the new function
373
  inputs=inputs,
374
+ outputs=[result, gr_metadata, history_dropdown], # Add history_dropdown to outputs
375
  api_name=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  )
377
 
378
+ # Add event handler for generate_from_json button
379
+ generate_from_json.click(
380
+ fn=generate_and_update_history,
381
+ inputs=inputs,
382
+ outputs=[result, gr_metadata, history_dropdown],
383
+ api_name=False,
384
  )
385
 
386
+ # Add event handler for history_dropdown
387
  history_dropdown.change(
388
+ fn=display_history_item,
389
+ inputs=[history_dropdown],
390
+ outputs=[history_image, history_metadata],
391
  )
392
 
393
+ demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)