AP123 commited on
Commit
cf296f0
β€’
1 Parent(s): 633ca23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -13
app.py CHANGED
@@ -5,19 +5,23 @@ from PIL import Image
5
  import random
6
  from diffusers import DiffusionPipeline
7
 
8
- # Initialize DiffusionPipeline with LoRA weights
9
  pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
10
  pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora")
 
 
 
11
 
12
  def text_to_image(prompt):
13
- generated_img = pipeline(prompt)
14
- return generated_img
 
 
 
 
15
 
16
  def create_cereal_box(input_image):
17
- cover_img = input_image # This should already be a PIL Image
18
- template_img = Image.open('CerealBoxMaker/template.jpeg')
19
-
20
- # Cereal box creation logic
21
  scaling_factor = 1.5
22
  rect_height = int(template_img.height * 0.32)
23
  new_width = int(rect_height * 0.70)
@@ -34,15 +38,32 @@ def create_cereal_box(input_image):
34
  template_copy = template_img.copy()
35
  template_copy.paste(cover_resized_scaled, left_position)
36
  template_copy.paste(cover_resized_scaled, right_position)
37
-
38
- # Convert to a numpy array for Gradio output
39
  template_copy_array = np.array(template_copy)
40
  return template_copy_array
41
 
42
  def combined_function(prompt):
43
- generated_img = text_to_image(prompt)
44
- final_img = create_cereal_box(generated_img)
45
  return final_img
46
 
47
- # Create Gradio Interface
48
- gr.Interface(fn=combined_function, inputs="text", outputs="image").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import random
6
  from diffusers import DiffusionPipeline
7
 
 
8
  pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
9
  pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora")
10
+ pipeline.to("cuda:0")
11
+
12
+ MAX_SEED = np.iinfo(np.int32).max
13
 
14
  def text_to_image(prompt):
15
+ seed = random.randint(0, MAX_SEED)
16
+ negative_prompt = "ugly, blurry, nsfw, gore, blood"
17
+ output = pipeline(prompt=prompt, negative_prompt=negative_prompt, width=1024, height=1024, guidance_scale=7.0, num_inference_steps=25, generator=torch.Generator().manual_seed(seed))
18
+ generated_img = output.images[0]
19
+ generated_img_array = np.array(generated_img)
20
+ return generated_img_array
21
 
22
  def create_cereal_box(input_image):
23
+ cover_img = Image.fromarray(input_image.astype('uint8'), 'RGB')
24
+ template_img = Image.open('/content/866b9b8f50b50879120be0b87dfd6050.jpg')
 
 
25
  scaling_factor = 1.5
26
  rect_height = int(template_img.height * 0.32)
27
  new_width = int(rect_height * 0.70)
 
38
  template_copy = template_img.copy()
39
  template_copy.paste(cover_resized_scaled, left_position)
40
  template_copy.paste(cover_resized_scaled, right_position)
 
 
41
  template_copy_array = np.array(template_copy)
42
  return template_copy_array
43
 
44
  def combined_function(prompt):
45
+ generated_img_array = text_to_image(prompt)
46
+ final_img = create_cereal_box(generated_img_array)
47
  return final_img
48
 
49
+ with gr.Blocks() as app:
50
+ gr.HTML("<div style='text-align: center;'><h1>Cereal Box Maker πŸ₯£</h1></div>")
51
+ gr.HTML("<div style='text-align: center;'><p>This application uses StableDiffusion XL to create any cereal box you could ever imagine!</p></div>")
52
+ gr.HTML("<div style='text-align: center;'><h3>Instructions:</h3><ol><li>Describe the cereal box you want to create and hit generate!</li><li>Print it out, cut the outside, fold the lines, and then tape!</li></ol></div>")
53
+ gr.HTML("<div style='text-align: center;'><p>A space by AP 🐧, follow me on <a href='https://twitter.com/angrypenguinPNG'>Twitter</a>! H/T to OstrisAI <a href='https://twitter.com/ostrisai'>Twitter</a> for their Cereal Box LoRA!</p></div>")
54
+
55
+ with gr.Row():
56
+ textbox = gr.Textbox(label="Describe your cereal box: Ex: 'Avengers Cereal'")
57
+ btn_generate = gr.Button("Generate", label="Generate")
58
+
59
+ with gr.Row():
60
+ output_img = gr.Image(label="Your Custom Cereal Box")
61
+
62
+ btn_generate.click(
63
+ combined_function,
64
+ inputs=[textbox],
65
+ outputs=[output_img]
66
+ )
67
+
68
+ app.queue(concurrency_count=4, max_size=20, api_open=False)
69
+ app.launch(debug=True)