Files changed (1) hide show
  1. modeling_baichuan.py +6 -3
modeling_baichuan.py CHANGED
@@ -358,9 +358,12 @@ class BaichuanModel(BaichuanPreTrainedModel):
358
  expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
359
  inputs_embeds.device
360
  )
361
- combined_attention_mask = (
362
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
363
- )
 
 
 
364
 
365
  return combined_attention_mask
366
 
 
358
  expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
359
  inputs_embeds.device
360
  )
361
+ if combined_attention_mask is None:
362
+ combined_attention_mask = expanded_attn_mask
363
+ else:
364
+ expanded_attn_mask = torch.where(expanded_attn_mask == torch.finfo(inputs_embeds.dtype).min, torch.finfo(inputs_embeds.dtype).min / 2, expanded_attn_mask)
365
+ combined_attention_mask = torch.where(combined_attention_mask == torch.finfo(inputs_embeds.dtype).min, torch.finfo(inputs_embeds.dtype).min / 2, combined_attention_mask)
366
+ combined_attention_mask = expanded_attn_mask + combined_attention_mask
367
 
368
  return combined_attention_mask
369