gokaygokay committed on
Commit
4d122d3
1 Parent(s): d8861a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -5
app.py CHANGED
@@ -9,10 +9,33 @@ from diffusers.models.attention_processor import AttnProcessor2_0
9
  import gradio as gr
10
  from PIL import Image
11
  from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
 
 
12
 
13
  import subprocess
14
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Download the model files
17
  ckpt_dir = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
18
 
@@ -33,7 +56,6 @@ pipe.unet.set_attn_processor(AttnProcessor2_0())
33
  # Define samplers
34
  samplers = {
35
  "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
36
- "DPM++ 2M": DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, algorithm_type="dpmsolver++", use_karras_sigmas=True),
37
  "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
38
  }
39
 
@@ -51,6 +73,12 @@ florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base',
51
  enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
52
  enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
53
 
 
 
 
 
 
 
54
  # Florence caption function
55
  def florence_caption(image):
56
  # Convert image to PIL if it's not already
@@ -85,11 +113,21 @@ def enhance_prompt(input_prompt, model_choice):
85
 
86
  return enhanced_text
87
 
 
 
 
 
 
 
 
 
 
88
  @spaces.GPU(duration=120)
89
  def generate_image(additional_positive_prompt, additional_negative_prompt, height, width, num_inference_steps,
90
  guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler, clip_skip,
91
  use_florence2, use_medium_enhancer, use_long_enhancer,
92
  use_positive_prefix, use_positive_suffix, use_negative_prefix, use_negative_suffix,
 
93
  input_image=None, progress=gr.Progress(track_tqdm=True)):
94
 
95
  if use_random_seed:
@@ -138,7 +176,7 @@ def generate_image(additional_positive_prompt, additional_negative_prompt, heigh
138
  full_negative_prompt += f", {DEFAULT_NEGATIVE_SUFFIX}"
139
 
140
  try:
141
- image = pipe(
142
  prompt=full_positive_prompt,
143
  negative_prompt=full_negative_prompt,
144
  height=height,
@@ -148,7 +186,15 @@ def generate_image(additional_positive_prompt, additional_negative_prompt, heigh
148
  num_images_per_prompt=num_images_per_prompt,
149
  generator=torch.Generator(pipe.device).manual_seed(seed)
150
  ).images
151
- return image, seed, full_positive_prompt, full_negative_prompt
 
 
 
 
 
 
 
 
152
  except Exception as e:
153
  print(f"Error during image generation: {str(e)}")
154
  return None, seed, full_positive_prompt, full_negative_prompt
@@ -188,6 +234,10 @@ with gr.Blocks(theme='bethecloud/storj_theme') as demo:
188
  use_medium_enhancer = gr.Checkbox(label="Use Medium Prompt Enhancer", value=False)
189
  use_long_enhancer = gr.Checkbox(label="Use Long Prompt Enhancer", value=False)
190
 
 
 
 
 
191
  generate_btn = gr.Button("Generate Image")
192
 
193
  with gr.Accordion("Prefix and Suffix Settings", open=True):
@@ -211,8 +261,6 @@ with gr.Blocks(theme='bethecloud/storj_theme') as demo:
211
  value=True,
212
  info=f"Suffix: {DEFAULT_NEGATIVE_SUFFIX}"
213
  )
214
-
215
-
216
 
217
  with gr.Column(scale=1):
218
  output_gallery = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
@@ -227,6 +275,7 @@ with gr.Blocks(theme='bethecloud/storj_theme') as demo:
227
  guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler,
228
  clip_skip, use_florence2, use_medium_enhancer, use_long_enhancer,
229
  use_positive_prefix, use_positive_suffix, use_negative_prefix, use_negative_suffix,
 
230
  input_image
231
  ],
232
  outputs=[output_gallery, seed_used, full_positive_prompt_used, full_negative_prompt_used]
 
9
  import gradio as gr
10
  from PIL import Image
11
  from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
12
+ import requests
13
+ from RealESRGAN import RealESRGAN
14
 
15
  import subprocess
16
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
17
 
18
def download_file(url, folder_path, filename, timeout=30):
    """Download *url* into ``folder_path/filename``, skipping the fetch if present.

    Args:
        url: HTTP(S) URL of the file to fetch.
        folder_path: Destination directory; created if it does not exist.
        filename: Name to store the file under inside *folder_path*.
        timeout: Seconds to wait on the server before aborting (new parameter,
            defaults to 30 so existing callers are unaffected; the original
            code could hang indefinitely with no timeout).

    Returns:
        The local file path (str) when the file exists or was downloaded
        successfully; ``None`` when the download failed.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)

    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return file_path

    # Stream to disk in chunks so large model checkpoints never sit in memory.
    response = requests.get(url, stream=True, timeout=timeout)
    if response.status_code == 200:
        with open(file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024):
                file.write(chunk)
        # Report success once, after the file handle is flushed and closed.
        print(f"File successfully downloaded and saved: {file_path}")
        return file_path

    print(f"Error downloading the file. Status code: {response.status_code}")
    return None
34
+
35
+ # Download ESRGAN models
36
+ download_file("https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x2.pth", "models/upscalers/", "RealESRGAN_x2.pth")
37
+ download_file("https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4.pth", "models/upscalers/", "RealESRGAN_x4.pth")
38
+
39
  # Download the model files
40
  ckpt_dir = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
41
 
 
56
  # Define samplers
57
  samplers = {
58
  "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
 
59
  "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
60
  }
61
 
 
73
  enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
74
  enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
75
 
76
+ # Initialize ESRGAN models
77
+ realesrgan_x2 = RealESRGAN(device, scale=2)
78
+ realesrgan_x2.load_weights('models/upscalers/RealESRGAN_x2.pth', download=False)
79
+ realesrgan_x4 = RealESRGAN(device, scale=4)
80
+ realesrgan_x4.load_weights('models/upscalers/RealESRGAN_x4.pth', download=False)
81
+
82
  # Florence caption function
83
  def florence_caption(image):
84
  # Convert image to PIL if it's not already
 
113
 
114
  return enhanced_text
115
 
116
# Real-ESRGAN upscaling helper
def upscale_image(image, scale):
    """Return *image* upscaled by Real-ESRGAN at the requested factor.

    Only factors 2 and 4 have loaded models; any other *scale* value is a
    no-op and the input image is handed back unchanged.
    """
    if scale == 2:
        return realesrgan_x2.predict(image)
    if scale == 4:
        return realesrgan_x4.predict(image)
    # Unsupported factor: pass the image through untouched.
    return image
124
+
125
  @spaces.GPU(duration=120)
126
  def generate_image(additional_positive_prompt, additional_negative_prompt, height, width, num_inference_steps,
127
  guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler, clip_skip,
128
  use_florence2, use_medium_enhancer, use_long_enhancer,
129
  use_positive_prefix, use_positive_suffix, use_negative_prefix, use_negative_suffix,
130
+ use_upscaler, upscale_factor,
131
  input_image=None, progress=gr.Progress(track_tqdm=True)):
132
 
133
  if use_random_seed:
 
176
  full_negative_prompt += f", {DEFAULT_NEGATIVE_SUFFIX}"
177
 
178
  try:
179
+ images = pipe(
180
  prompt=full_positive_prompt,
181
  negative_prompt=full_negative_prompt,
182
  height=height,
 
186
  num_images_per_prompt=num_images_per_prompt,
187
  generator=torch.Generator(pipe.device).manual_seed(seed)
188
  ).images
189
+
190
+ if use_upscaler:
191
+ upscaled_images = []
192
+ for img in images:
193
+ upscaled_img = upscale_image(img, upscale_factor)
194
+ upscaled_images.append(upscaled_img)
195
+ images = upscaled_images
196
+
197
+ return images, seed, full_positive_prompt, full_negative_prompt
198
  except Exception as e:
199
  print(f"Error during image generation: {str(e)}")
200
  return None, seed, full_positive_prompt, full_negative_prompt
 
234
  use_medium_enhancer = gr.Checkbox(label="Use Medium Prompt Enhancer", value=False)
235
  use_long_enhancer = gr.Checkbox(label="Use Long Prompt Enhancer", value=False)
236
 
237
+ with gr.Accordion("Upscaler Settings", open=False):
238
+ use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
239
+ upscale_factor = gr.Radio(label="Upscale Factor", choices=[2, 4], value=2)
240
+
241
  generate_btn = gr.Button("Generate Image")
242
 
243
  with gr.Accordion("Prefix and Suffix Settings", open=True):
 
261
  value=True,
262
  info=f"Suffix: {DEFAULT_NEGATIVE_SUFFIX}"
263
  )
 
 
264
 
265
  with gr.Column(scale=1):
266
  output_gallery = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
 
275
  guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler,
276
  clip_skip, use_florence2, use_medium_enhancer, use_long_enhancer,
277
  use_positive_prefix, use_positive_suffix, use_negative_prefix, use_negative_suffix,
278
+ use_upscaler, upscale_factor,
279
  input_image
280
  ],
281
  outputs=[output_gallery, seed_used, full_positive_prompt_used, full_negative_prompt_used]