Update modeling_phi.py
Browse files
- modeling_phi.py +4 -1
modeling_phi.py
CHANGED
@@ -362,7 +362,10 @@ class PhiAttention(nn.Module):
|
|
362 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
363 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
364 |
|
365 |
-
|
|
|
|
|
|
|
366 |
|
367 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
368 |
raise ValueError(
|
|
|
362 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
363 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
364 |
|
365 |
+
# Upcasting queries and keys to fp32 is required by Phi-2 to avoid overflow
|
366 |
+
attn_weights = torch.matmul(
|
367 |
+
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
368 |
+
) / math.sqrt(self.head_dim)
|
369 |
|
370 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
371 |
raise ValueError(
|