robinzixuan committed
Commit
3fa0828
1 Parent(s): b36e70f

Update modeling_opt.py

Files changed (1)
  1. modeling_opt.py +4 -3
modeling_opt.py CHANGED
@@ -32,6 +32,7 @@ from transformers.modeling_outputs import (
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
 )
+
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
     add_code_sample_docstrings,
@@ -259,10 +260,10 @@ class OPTAttention(nn.Module):
 
         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
         if attn_weights.dtype == torch.float16:
-            attn_weights = nn.functional.softmax(
+            attn_weights = softmax_1(
                 attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
         else:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            attn_weights = softmax_1(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
             if layer_head_mask.size() != (self.num_heads,):
@@ -489,7 +490,7 @@ class OPTOutEffHop(OPTAttention):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-class OptFlashAttention2(OPTOutEffHop):
+class OptFlashAttention2(OPTAttention):
     """
     OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
     The only required change would be on the forward pass where it needs to correctly call the public API of flash
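For context, `softmax_1` is not defined in this hunk; it presumably lives elsewhere in `modeling_opt.py`. Assuming it denotes the "off-by-one" softmax used in outlier-efficiency work (a normalizer with an extra constant 1, equivalent to appending an implicit zero logit, so a row can assign less than full attention mass), a minimal sketch that is call-compatible with the two sites above might look like the following. The implementation details here are illustrative, not the repository's actual helper.

```python
import torch

def softmax_1(x: torch.Tensor, dim: int = -1, dtype=None) -> torch.Tensor:
    # Off-by-one softmax: exp(x_i) / (1 + sum_j exp(x_j)).
    # Equivalent to a regular softmax over [x, 0], keeping only the x entries,
    # so a row can assign (near-)zero total attention weight.
    if dtype is not None:
        x = x.to(dtype)
    # Fold the implicit zero logit into the max-shift for numerical stability:
    # after shifting by m, the implicit logit contributes exp(-m) to the denominator.
    m = x.max(dim=dim, keepdim=True).values.clamp(min=0)
    exp_x = torch.exp(x - m)
    return exp_x / (torch.exp(-m) + exp_x.sum(dim=dim, keepdim=True))

# Mirrors the fp16 call site in the diff: upcast, normalize, downcast.
attn = torch.randn(2, 8, dtype=torch.float16)
out = softmax_1(attn, dim=-1, dtype=torch.float32).to(torch.float16)
print(out.sum(-1))  # each row sums to strictly less than 1
```

Because both call sites in the diff pass only `dim` and `dtype`, a helper with this signature can replace `nn.functional.softmax` without touching the surrounding fp16-upcast logic.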