FridaKahlosArtCenter committed on
Commit
90e022a
1 Parent(s): 2740cfb

colab working script

Browse files
app.py CHANGED
@@ -4,46 +4,131 @@ from PIL import Image, ImageDraw, ImageFont
4
  import requests
5
  from io import BytesIO
6
  import gradio as gr
 
 
7
 
8
- # log gpu availabilitu
9
  print(f"Is CUDA available: {torch.cuda.is_available()}")
10
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
11
 
12
 
13
  def image_to_template(generated_image, logo, button_text, punchline, theme_color):
14
- # Resize logo if needed
15
- logo = logo.resize((100, 100)) # Example size, adjust as needed
16
-
17
- # Create a blank canvas with extra space for logo, punchline, and button
18
- canvas_width = max(generated_image.width, logo.width) * 2
19
- canvas_height = generated_image.height + logo.height + 100
20
- canvas = Image.new('RGB', (canvas_width, canvas_height), 'white')
21
-
22
- # Paste the logo and the generated image onto the canvas
23
- canvas.paste(logo, (10, 10)) # Adjust position as needed
24
- canvas.paste(generated_image, (0, logo.height + 20))
25
-
26
- # Add punchline and button
27
- draw = ImageDraw.Draw(canvas)
28
- font = ImageFont.load_default() # Or use a custom font
29
- text_color = theme_color
30
-
31
- # Draw punchline
32
- draw.text((10, logo.height + generated_image.height + 30), punchline, fill=text_color, font=font)
33
-
34
- # Draw button
35
- button_position = (10, logo.height + generated_image.height + 60) # Adjust as needed
36
- draw.rectangle([button_position, (canvas_width - 10, canvas_height - 10)], outline=theme_color, fill=text_color)
37
- draw.text(button_position, button_text, font=font)
38
-
39
- return canvas
40
-
41
- def generate_template(initial_image, logo, prompt, button_text, punchline, image_color, theme_color):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  pipeline = AutoPipelineForImage2Image.from_pretrained(
43
- "./models/kandinsky-2-2-decoder",
44
- torch_dtype=torch.float16,
45
- variant="fp16",
46
- use_safetensors=True
47
  )
48
 
49
  # pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
@@ -53,34 +138,53 @@ def generate_template(initial_image, logo, prompt, button_text, punchline, image
53
  negative_prompt = "low quality, bad quality, blurry, unprofessional"
54
 
55
  generated_image = pipeline(
56
- prompt=prompt,
57
  negative_prompt=negative_prompt,
58
- image=initial_image,
59
- height=256,
60
- width=256).images[0]
 
61
 
62
- template_image = image_to_template(generated_image, logo, button_text, punchline, theme_color)
 
 
 
 
 
 
 
63
 
64
  return template_image
65
 
 
66
  # Set up Gradio interface
67
  iface = gr.Interface(
68
  fn=generate_template,
69
- inputs=[gr.inputs.Image(type="pil", label="Initial Image"),
70
- gr.inputs.Image(type="pil", label="Logo"),
71
- gr.inputs.Textbox(label="Prompt"),
72
- gr.inputs.Textbox(label="Button Text"),
73
- gr.inputs.Textbox(label="Punchline"),
74
- gr.inputs.ColorPicker(label="Image Color"),
75
- gr.inputs.ColorPicker(label="Theme Color")],
76
- outputs=[gr.outputs.Image(type="pil")],
 
 
77
  title="Ad Template Generation Using Diffusion Models Demo",
78
  description="Generate ad template based on your inputs using a trained model.",
79
  concurrency_limit=2,
80
- # examples=[
81
- # []
82
- # ]
 
 
 
 
 
 
 
 
83
  )
84
 
85
  # Run the interface
86
- iface.launch()
 
4
  import requests
5
  from io import BytesIO
6
  import gradio as gr
7
+ import gc
8
+ import textwrap
9
 
10
+ # log gpu availability
11
  print(f"Is CUDA available: {torch.cuda.is_available()}")
12
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
13
 
14
 
15
  def image_to_template(generated_image, logo, button_text, punchline, theme_color):
16
+ template_width = 540
17
+ button_font_size = 10
18
+ punchline_font_size = 30
19
+ decoration_height = 10
20
+ margin = 20
21
+ # wrap punchline text
22
+ punchline = textwrap.wrap(punchline, width=35)
23
+ n_of_lines_punchline = len(punchline)
24
+
25
+ generated_image = generated_image.convert("RGBA")
26
+ logo = logo.convert("RGBA")
27
+
28
+ # image shape
29
+ image_width = template_width // 2
30
+ image_height = image_width * generated_image.height // generated_image.width
31
+ image_shape = (image_width, image_height)
32
+
33
+ # logo shape
34
+ logo_width = image_width // 3
35
+ logo_height = logo_width * logo.height // logo.width
36
+ logo_shape = (logo_width, logo_height)
37
+
38
+ # Define fonts
39
+ button_font = ImageFont.truetype("Montserrat-Bold.ttf", button_font_size)
40
+ punchline_font = ImageFont.truetype("Montserrat-Bold.ttf", punchline_font_size)
41
+
42
+ # button shape
43
+ button_width = template_width // 3
44
+ button_height = button_font_size * 3
45
+
46
+ # template height calculation
47
+ template_height = (
48
+ image_height
49
+ + logo_height
50
+ + button_height
51
+ + n_of_lines_punchline * punchline_font_size
52
+ + (5 * margin)
53
+ + (2 * decoration_height)
54
+ )
55
+
56
+ # Calculate positions for the centered layout
57
+ logo_pos = ((template_width - logo_width) // 2, margin + decoration_height)
58
+ image_pos = (
59
+ (template_width - image_width) // 2,
60
+ logo_pos[1] + logo_height + margin,
61
+ )
62
+
63
+ # Decoration positions
64
+ top_decoration_pos = [
65
+ margin,
66
+ -decoration_height // 2,
67
+ template_width - margin,
68
+ decoration_height // 2,
69
+ ]
70
+ bottom_decoration_pos = [
71
+ margin,
72
+ template_height - decoration_height // 2,
73
+ template_width - margin,
74
+ template_height + decoration_height // 2,
75
+ ]
76
+
77
+ # Generate Components
78
+ generated_image.thumbnail(image_shape, Image.ANTIALIAS)
79
+ logo.thumbnail(logo_shape, Image.ANTIALIAS)
80
+ background = Image.new("RGBA", (template_width, template_height), "WHITE")
81
+ # round the corners of generated image
82
+ mask = Image.new("L", generated_image.size, 0)
83
+ draw = ImageDraw.Draw(mask)
84
+ draw.rounded_rectangle((0, 0) + generated_image.size, 20, fill=255)
85
+ generated_image.putalpha(mask)
86
+ # Paste the logo and the generated image onto the background
87
+ background.paste(logo, logo_pos, logo)
88
+ background.paste(generated_image, image_pos, generated_image)
89
+ # Draw the decorations, punchline, and button
90
+ draw = ImageDraw.Draw(background)
91
+ # Decorations on top and bottom
92
+ draw.rounded_rectangle(bottom_decoration_pos, radius=20, fill=theme_color)
93
+ draw.rounded_rectangle(top_decoration_pos, radius=20, fill=theme_color)
94
+ # Punchline text
95
+ text_heights = []
96
+ for line in punchline:
97
+ text_width, text_height = draw.textsize(line, font=punchline_font)
98
+ punchline_pos = (
99
+ (template_width - text_width) // 2,
100
+ image_pos[1] + generated_image.height + margin + sum(text_heights),
101
+ )
102
+ draw.text(punchline_pos, line, fill=theme_color, font=punchline_font)
103
+ text_heights.append(text_height)
104
+
105
+ # Button with rounded corners
106
+ button_text_width, button_text_height = draw.textsize(button_text, font=button_font)
107
+ button_shape = [
108
+ ((template_width - button_width) // 2, punchline_pos[1] + text_height + margin),
109
+ (
110
+ (template_width + button_width) // 2,
111
+ punchline_pos[1] + text_height + margin + button_height,
112
+ ),
113
+ ]
114
+ draw.rounded_rectangle(button_shape, radius=20, fill=theme_color)
115
+ # Button text
116
+ button_text_pos = (
117
+ (template_width - button_text_width) // 2,
118
+ button_shape[0][1] + (button_height - button_text_height) // 2,
119
+ )
120
+ draw.text(button_text_pos, button_text, fill="white", font=button_font)
121
+
122
+ return background
123
+
124
+
125
+ def generate_template(
126
+ initial_image, logo, prompt, button_text, punchline, image_color, theme_color
127
+ ):
128
  pipeline = AutoPipelineForImage2Image.from_pretrained(
129
+ "./models/kandinsky-2-2-decoder",
130
+ torch_dtype=torch.float16,
131
+ use_safetensors=True,
 
132
  )
133
 
134
  # pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
 
138
  negative_prompt = "low quality, bad quality, blurry, unprofessional"
139
 
140
  generated_image = pipeline(
141
+ prompt=prompt,
142
  negative_prompt=negative_prompt,
143
+ image=initial_image,
144
+ height=256,
145
+ width=256,
146
+ ).images[0]
147
 
148
+ template_image = image_to_template(
149
+ generated_image, logo, button_text, punchline, theme_color
150
+ )
151
+
152
+ # free cpu and gpu memory
153
+ del pipeline
154
+ gc.collect()
155
+ torch.cuda.empty_cache()
156
 
157
  return template_image
158
 
159
+
160
  # Set up Gradio interface
161
  iface = gr.Interface(
162
  fn=generate_template,
163
+ inputs=[
164
+ gr.Image(type="pil", label="Initial Image"),
165
+ gr.Image(type="pil", label="Logo"),
166
+ gr.Textbox(label="Prompt"),
167
+ gr.Textbox(label="Button Text"),
168
+ gr.Textbox(label="Punchline"),
169
+ gr.ColorPicker(label="Image Color"),
170
+ gr.ColorPicker(label="Theme Color"),
171
+ ],
172
+ outputs=[gr.Image(type="pil")],
173
  title="Ad Template Generation Using Diffusion Models Demo",
174
  description="Generate ad template based on your inputs using a trained model.",
175
  concurrency_limit=2,
176
+ examples=[
177
+ [
178
+ "/path/to/example_initial_image1.jpg", # Initial Image
179
+ "/path/to/example_logo1.png", # Logo
180
+ "A scenic mountain landscape", # Prompt
181
+ "Discover More", # Button Text
182
+ "Escape into Nature", # Punchline
183
+ "#00FF00", # Image Color
184
+ "#0000FF", # Theme Color
185
+ ]
186
+ ],
187
  )
188
 
189
  # Run the interface
190
+ iface.launch(debug=True)
assets/Montserrat-Bold.ttf ADDED
Binary file (29.6 kB). View file
 
assets/Montserrat-Regular.ttf ADDED
Binary file (29 kB). View file
 
assets/city_image.jpg CHANGED
assets/logo.jpg DELETED
Binary file (10.1 kB)
 
assets/logo.png ADDED
requirements.txt CHANGED
@@ -3,6 +3,6 @@ transformers
3
  --extra-index-url https://download.pytorch.org/whl/cu113
4
  torch
5
  gradio
6
- PIL
7
  requests
8
  safetensors
 
3
  --extra-index-url https://download.pytorch.org/whl/cu113
4
  torch
5
  gradio
6
+ pillow
7
  requests
8
  safetensors