robinzixuan
committed on
Commit
•
3fa0828
1
Parent(s):
b36e70f
Update modeling_opt.py
Browse files- modeling_opt.py +4 -3
modeling_opt.py
CHANGED
@@ -32,6 +32,7 @@ from transformers.modeling_outputs import (
|
|
32 |
QuestionAnsweringModelOutput,
|
33 |
SequenceClassifierOutputWithPast,
|
34 |
)
|
|
|
35 |
from transformers.modeling_utils import PreTrainedModel
|
36 |
from transformers.utils import (
|
37 |
add_code_sample_docstrings,
|
@@ -259,10 +260,10 @@ class OPTAttention(nn.Module):
|
|
259 |
|
260 |
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
261 |
if attn_weights.dtype == torch.float16:
|
262 |
-
attn_weights =
|
263 |
attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
|
264 |
else:
|
265 |
-
attn_weights =
|
266 |
|
267 |
if layer_head_mask is not None:
|
268 |
if layer_head_mask.size() != (self.num_heads,):
|
@@ -489,7 +490,7 @@ class OPTOutEffHop(OPTAttention):
|
|
489 |
return attn_output, attn_weights_reshaped, past_key_value
|
490 |
|
491 |
|
492 |
-
class OptFlashAttention2(
|
493 |
"""
|
494 |
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
|
495 |
The only required change would be on the forward pass where it needs to correctly call the public API of flash
|
|
|
32 |
QuestionAnsweringModelOutput,
|
33 |
SequenceClassifierOutputWithPast,
|
34 |
)
|
35 |
+
|
36 |
from transformers.modeling_utils import PreTrainedModel
|
37 |
from transformers.utils import (
|
38 |
add_code_sample_docstrings,
|
|
|
260 |
|
261 |
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
262 |
if attn_weights.dtype == torch.float16:
|
263 |
+
attn_weights = softmax_1(
|
264 |
attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
|
265 |
else:
|
266 |
+
attn_weights = softmax_1(attn_weights, dim=-1)
|
267 |
|
268 |
if layer_head_mask is not None:
|
269 |
if layer_head_mask.size() != (self.num_heads,):
|
|
|
490 |
return attn_output, attn_weights_reshaped, past_key_value
|
491 |
|
492 |
|
493 |
+
class OptFlashAttention2(OPTAttention):
|
494 |
"""
|
495 |
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
|
496 |
The only required change would be on the forward pass where it needs to correctly call the public API of flash
|