gokaygokay commited on
Commit
b151eb9
1 Parent(s): e0f49e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -9
app.py CHANGED
@@ -39,21 +39,33 @@ download_file("https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealES
39
  download_file("https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth?download=true", "models/upscalers/", "RealESRGAN_x4.pth")
40
 
41
  # Download the model files
42
- ckpt_dir = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
 
43
 
44
  # Load the models
45
- vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)
 
46
 
47
- pipe = StableDiffusionXLPipeline.from_pretrained(
48
- ckpt_dir,
49
- vae=vae,
50
  torch_dtype=torch.float16,
51
  use_safetensors=True,
52
  variant="fp16"
53
  )
54
- pipe = pipe.to("cuda")
 
 
 
 
 
 
 
 
 
55
 
56
- pipe.unet.set_attn_processor(AttnProcessor2_0())
 
57
 
58
  # Define samplers
59
  samplers = {
@@ -145,13 +157,16 @@ def upscale_image(image, scale):
145
  return image
146
 
147
  @spaces.GPU(duration=120)
148
- def generate_image(additional_positive_prompt, additional_negative_prompt, height, width, num_inference_steps,
149
  guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler, clip_skip,
150
  use_florence2, use_medium_enhancer, use_long_enhancer,
151
  use_positive_prefix, use_positive_suffix, use_negative_prefix, use_negative_suffix,
152
  use_upscaler, upscale_factor,
153
  input_image=None, progress=gr.Progress(track_tqdm=True)):
154
 
 
 
 
155
  if use_random_seed:
156
  seed = random.randint(0, 2**32 - 1)
157
  else:
@@ -242,8 +257,13 @@ with gr.Blocks(theme='bethecloud/storj_theme') as demo:
242
  </p>
243
  """)
244
 
245
- with gr.Row():
246
  with gr.Column(scale=1):
 
 
 
 
 
247
  positive_prompt = gr.Textbox(label="Positive Prompt", placeholder="Add your positive prompt here")
248
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Add your negative prompt here")
249
 
@@ -301,6 +321,7 @@ with gr.Blocks(theme='bethecloud/storj_theme') as demo:
301
  generate_btn.click(
302
  fn=generate_image,
303
  inputs=[
 
304
  positive_prompt, negative_prompt, height, width, num_inference_steps,
305
  guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler,
306
  clip_skip, use_florence2, use_medium_enhancer, use_long_enhancer,
 
39
  download_file("https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth?download=true", "models/upscalers/", "RealESRGAN_x4.pth")
40
 
41
  # Download the model files
42
+ ckpt_dir_pony = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
43
+ ckpt_dir_cyber = snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl")
44
 
45
  # Load the models
46
+ vae_pony = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_pony, "vae"), torch_dtype=torch.float16)
47
+ vae_cyber = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_cyber, "vae"), torch_dtype=torch.float16)
48
 
49
+ pipe_pony = StableDiffusionXLPipeline.from_pretrained(
50
+ ckpt_dir_pony,
51
+ vae=vae_pony,
52
  torch_dtype=torch.float16,
53
  use_safetensors=True,
54
  variant="fp16"
55
  )
56
+ pipe_cyber = StableDiffusionXLPipeline.from_pretrained(
57
+ ckpt_dir_cyber,
58
+ vae=vae_cyber,
59
+ torch_dtype=torch.float16,
60
+ use_safetensors=True,
61
+ variant="fp16"
62
+ )
63
+
64
+ pipe_pony = pipe_pony.to("cuda")
65
+ pipe_cyber = pipe_cyber.to("cuda")
66
 
67
+ pipe_pony.unet.set_attn_processor(AttnProcessor2_0())
68
+ pipe_cyber.unet.set_attn_processor(AttnProcessor2_0())
69
 
70
  # Define samplers
71
  samplers = {
 
157
  return image
158
 
159
  @spaces.GPU(duration=120)
160
+ def generate_image(model_choice, additional_positive_prompt, additional_negative_prompt, height, width, num_inference_steps,
161
  guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler, clip_skip,
162
  use_florence2, use_medium_enhancer, use_long_enhancer,
163
  use_positive_prefix, use_positive_suffix, use_negative_prefix, use_negative_suffix,
164
  use_upscaler, upscale_factor,
165
  input_image=None, progress=gr.Progress(track_tqdm=True)):
166
 
167
+ # Select the appropriate pipe based on the model choice
168
+ pipe = pipe_pony if model_choice == "Pony Realism v21" else pipe_cyber
169
+
170
  if use_random_seed:
171
  seed = random.randint(0, 2**32 - 1)
172
  else:
 
257
  </p>
258
  """)
259
 
260
+ with gr.Row():
261
  with gr.Column(scale=1):
262
+ model_choice = gr.Dropdown(
263
+ label="Model",
264
+ choices=["Pony Realism v21", "Cyberrealistic Pony v61"],
265
+ value="Pony Realism v21"
266
+ )
267
  positive_prompt = gr.Textbox(label="Positive Prompt", placeholder="Add your positive prompt here")
268
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Add your negative prompt here")
269
 
 
321
  generate_btn.click(
322
  fn=generate_image,
323
  inputs=[
324
+ model_choice, # Add this new input
325
  positive_prompt, negative_prompt, height, width, num_inference_steps,
326
  guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler,
327
  clip_skip, use_florence2, use_medium_enhancer, use_long_enhancer,