fix left padding batch infer #9
opened by kuaizhirui

modeling_baichuan.py  CHANGED  +7 -4
@@ -358,10 +358,13 @@ class BaichuanModel(BaichuanPreTrainedModel):
             expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                 inputs_embeds.device
             )
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
+            if combined_attention_mask is None:
+                combined_attention_mask = expanded_attn_mask
+            else:
+                expanded_attn_mask = torch.where(expanded_attn_mask == torch.finfo(inputs_embeds.dtype).min, torch.finfo(inputs_embeds.dtype).min / 2, expanded_attn_mask)
+                combined_attention_mask = torch.where(combined_attention_mask == torch.finfo(inputs_embeds.dtype).min, torch.finfo(inputs_embeds.dtype).min / 2, combined_attention_mask)
+                combined_attention_mask = expanded_attn_mask + combined_attention_mask
+
         return combined_attention_mask

     def forward(
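For context (not part of the PR): with left padding, a padded position is masked by both the causal mask already in combined_attention_mask and the expanded padding mask, and each uses torch.finfo(dtype).min as its "masked" value. In float16 that is -65504, so summing the two masks overflows to -inf, and a score row that is entirely -inf becomes NaN after softmax. Clamping each mask to finfo.min / 2 before the sum keeps the result finite. A minimal sketch of the failure and the fix, assuming float16 inference (the values and shapes below are illustrative, not taken from modeling_baichuan.py):

import torch

dtype = torch.float16
min_val = torch.finfo(dtype).min  # -65504.0 in float16

# Naive sum: a position masked by both the padding mask and the causal
# mask receives min + min, which overflows float16's range to -inf.
both_masked = torch.tensor([min_val], dtype=dtype) + torch.tensor([min_val], dtype=dtype)
print(both_masked)  # tensor([-inf], dtype=torch.float16)

# A score row that is entirely -inf turns into NaN after softmax
# (softmax subtracts the row max, and -inf - -inf is NaN).
scores = torch.full((1, 4), float("-inf"), dtype=dtype)
print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan]], dtype=torch.float16)

# With the patch, each mask contributes at most min / 2, so the combined
# mask bottoms out at exactly min and the softmax stays finite.
clamped = torch.tensor([min_val / 2], dtype=dtype) + torch.tensor([min_val / 2], dtype=dtype)
print(clamped)  # tensor([-65504.], dtype=torch.float16)
scores = torch.tensor([[0.0, min_val]], dtype=dtype)
print(torch.softmax(scores, dim=-1))  # tensor([[1., 0.]], dtype=torch.float16)

Dividing by 2 rather than substituting a smaller sentinel keeps the masked logits as negative as float16 allows while guaranteeing that the sum of the two masks cannot overflow.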