Prgckwb committed on
Commit
6a91e71
1 Parent(s): 3c9c988

:tada: init

Browse files
Files changed (2) hide show
  1. app.py +23 -7
  2. assets/ramen.jpg +0 -0
app.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import torch
 
10
  import torch.nn.functional as F
11
  from PIL import Image
12
  from pathlib import Path
@@ -188,7 +189,11 @@ unet.set_attn_processor(
188
 
189
 
190
  @torch.inference_mode()
191
- def inference(image_path: str, prompt: str, progress=gr.Progress(track_tqdm=False)):
 
 
 
 
192
  progress(0, "Initializing...")
193
  image = Image.open(image_path)
194
  image = image.convert("RGB").resize((512, 512))
@@ -269,7 +274,11 @@ def inference(image_path: str, prompt: str, progress=gr.Progress(track_tqdm=Fals
269
  ).squeeze(0) # (77, 512, 512)
270
 
271
  # <bos> と <eos> トークンの間に挿入されたトークンのみを取得
272
- mean_cross_attn_probs = mean_cross_attn_probs[:n_cond_tokens, ...] # (n_tokens, 512, 512)
 
 
 
 
273
  cross_attention_probs_list.append(mean_cross_attn_probs)
274
 
275
  # list -> torch.Tensor
@@ -281,7 +290,10 @@ def inference(image_path: str, prompt: str, progress=gr.Progress(track_tqdm=Fals
281
  image_list = []
282
  # 各行ごとに画像を作成し保存
283
  for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
284
- fig, ax = plt.subplots(1, n_cond_tokens, figsize=(16, 4)) # 行ごとに画像を作成
 
 
 
285
 
286
  for j in range(cross_attention_probs.shape[1]):
287
  # 各クラスのアテンションマップを Min-Max 正規化 (0~1)
@@ -297,12 +309,15 @@ def inference(image_path: str, prompt: str, progress=gr.Progress(track_tqdm=Fals
297
  # 各行ごとの画像を保存
298
  out_dir = Path("output")
299
  out_dir.mkdir(exist_ok=True)
300
- filepath = out_dir / f"output_row_{i}.png"
 
 
301
  plt.savefig(filepath, bbox_inches='tight', pad_inches=0)
302
  plt.close(fig)
303
 
304
  # 保存した画像をPILで読み込んでリストに追加
305
  image_list.append(Image.open(filepath))
 
306
  return image_list
307
 
308
 
@@ -333,13 +348,14 @@ if __name__ == '__main__':
333
  fn=inference,
334
  inputs=[
335
  gr.Image(type="filepath", label="Input", width=512, height=512),
336
- gr.Textbox(label="Prompt", placeholder="e.g.) A photo of dog...")
 
337
  ],
338
  outputs=ca_output,
339
  cache_examples=True,
340
  examples=[
341
- ["assets/aeroplane.png", "plane background"],
342
- ["assets/dogcat.png", "a photo of dog and cat"],
343
  ]
344
  )
345
 
 
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  import torch
10
+ import uuid
11
  import torch.nn.functional as F
12
  from PIL import Image
13
  from pathlib import Path
 
189
 
190
 
191
  @torch.inference_mode()
192
+ def inference(
193
+ image_path: str,
194
+ prompt: str,
195
+ has_include_special_tokens: bool = False,
196
+ progress=gr.Progress(track_tqdm=False)):
197
  progress(0, "Initializing...")
198
  image = Image.open(image_path)
199
  image = image.convert("RGB").resize((512, 512))
 
274
  ).squeeze(0) # (77, 512, 512)
275
 
276
  # <bos> と <eos> トークンの間に挿入されたトークンのみを取得
277
+ if has_include_special_tokens:
278
+ mean_cross_attn_probs = mean_cross_attn_probs[:n_cond_tokens, ...] # (n_tokens, 512, 512)
279
+ else:
280
+ mean_cross_attn_probs = mean_cross_attn_probs[1:n_cond_tokens - 1, ...]
281
+
282
  cross_attention_probs_list.append(mean_cross_attn_probs)
283
 
284
  # list -> torch.Tensor
 
290
  image_list = []
291
  # 各行ごとに画像を作成し保存
292
  for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
293
+ if has_include_special_tokens:
294
+ fig, ax = plt.subplots(1, n_cond_tokens, figsize=(16, 4))
295
+ else:
296
+ fig, ax = plt.subplots(1, n_cond_tokens - 2, figsize=(16, 4))
297
 
298
  for j in range(cross_attention_probs.shape[1]):
299
  # 各クラスのアテンションマップを Min-Max 正規化 (0~1)
 
309
  # 各行ごとの画像を保存
310
  out_dir = Path("output")
311
  out_dir.mkdir(exist_ok=True)
312
+ # 一意なランダムファイル名を生成
313
+ unique_filename = str(uuid.uuid4())
314
+ filepath = out_dir / f"{unique_filename}.png"
315
  plt.savefig(filepath, bbox_inches='tight', pad_inches=0)
316
  plt.close(fig)
317
 
318
  # 保存した画像をPILで読み込んでリストに追加
319
  image_list.append(Image.open(filepath))
320
+ attn_processor.reset_attention_stores()
321
  return image_list
322
 
323
 
 
348
  fn=inference,
349
  inputs=[
350
  gr.Image(type="filepath", label="Input", width=512, height=512),
351
+ gr.Textbox(label="Prompt", placeholder="e.g.) A photo of dog..."),
352
+ gr.Checkbox(label="Include Special Tokens", value=False),
353
  ],
354
  outputs=ca_output,
355
  cache_examples=True,
356
  examples=[
357
+ ["assets/aeroplane.png", "plane background", False],
358
+ ["assets/dogcat.png", "a photo of dog", False],
359
  ]
360
  )
361
 
assets/ramen.jpg ADDED