nyanko7 commited on
Commit
2ea58e8
1 Parent(s): 4508ef4

Update modules/model.py

Browse files
Files changed (1) hide show
  1. modules/model.py +1 -1
modules/model.py CHANGED
@@ -78,7 +78,7 @@ class CrossAttnProcessor(nn.Module):
78
  attention_mask=None,
79
  ):
80
  batch_size, sequence_length, _ = hidden_states.shape
81
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
82
 
83
  encoder_states = hidden_states
84
  is_xattn = False
 
78
  attention_mask=None,
79
  ):
80
  batch_size, sequence_length, _ = hidden_states.shape
81
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size)
82
 
83
  encoder_states = hidden_states
84
  is_xattn = False