Prgckwb commited on
Commit
86571af
1 Parent(s): 6a91e71

:tada: init

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -277,7 +277,7 @@ def inference(
277
  if has_include_special_tokens:
278
  mean_cross_attn_probs = mean_cross_attn_probs[:n_cond_tokens, ...] # (n_tokens, 512, 512)
279
  else:
280
- mean_cross_attn_probs = mean_cross_attn_probs[1:n_cond_tokens - 1, ...]
281
 
282
  cross_attention_probs_list.append(mean_cross_attn_probs)
283
 
@@ -290,10 +290,7 @@ def inference(
290
  image_list = []
291
  # 各行ごとに画像を作成し保存
292
  for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
293
- if has_include_special_tokens:
294
- fig, ax = plt.subplots(1, n_cond_tokens, figsize=(16, 4))
295
- else:
296
- fig, ax = plt.subplots(1, n_cond_tokens - 2, figsize=(16, 4))
297
 
298
  for j in range(cross_attention_probs.shape[1]):
299
  # 各クラスのアテンションマップを Min-Max 正規化 (0~1)
 
277
  if has_include_special_tokens:
278
  mean_cross_attn_probs = mean_cross_attn_probs[:n_cond_tokens, ...] # (n_tokens, 512, 512)
279
  else:
280
+ mean_cross_attn_probs = mean_cross_attn_probs[1:n_cond_tokens - 1, ...] # (n_tokens-2, 512, 512)
281
 
282
  cross_attention_probs_list.append(mean_cross_attn_probs)
283
 
 
290
  image_list = []
291
  # 各行ごとに画像を作成し保存
292
  for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
293
+ fig, ax = plt.subplots(1, n_cond_tokens, figsize=(16, 4))
 
 
 
294
 
295
  for j in range(cross_attention_probs.shape[1]):
296
  # 各クラスのアテンションマップを Min-Max 正規化 (0~1)