Prgckwb commited on
Commit
5fdc31e
1 Parent(s): 86571af

:tada: init

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -301,7 +301,10 @@ def inference(
301
  attn_probs = cross_attention_probs[i, j].cpu().detach().numpy()
302
  ax[j].imshow(attn_probs, alpha=0.9)
303
  ax[j].axis('off')
304
- ax[j].set_title(tokenizer.decode(input_ids[0, j].item()))
 
 
 
305
 
306
  # 各行ごとの画像を保存
307
  out_dir = Path("output")
 
301
  attn_probs = cross_attention_probs[i, j].cpu().detach().numpy()
302
  ax[j].imshow(attn_probs, alpha=0.9)
303
  ax[j].axis('off')
304
+ if has_include_special_tokens:
305
+ ax[j].set_title(tokenizer.decode(input_ids[0, j].item()))
306
+ else:
307
+ ax[j].set_title(tokenizer.decode(input_ids[0, j + 1].item()))
308
 
309
  # 各行ごとの画像を保存
310
  out_dir = Path("output")