Fix wrong tuple count issue after reapply
#6
by
kornfield
- opened
- modeling_mpt.py +6 -8
modeling_mpt.py
CHANGED
@@ -248,7 +248,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
248 |
|
249 |
return custom_forward
|
250 |
|
251 |
-
(x,
|
252 |
create_custom_forward(block),
|
253 |
x,
|
254 |
past_key_value,
|
@@ -256,15 +256,13 @@ class MPTModel(MPTPreTrainedModel):
|
|
256 |
attention_mask,
|
257 |
self.is_causal,
|
258 |
)
|
259 |
-
if past_key_values is not None:
|
260 |
-
past_key_values[b_idx] = past_key_value
|
261 |
else:
|
262 |
(x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
|
269 |
|
270 |
x = self.norm_f(x)
|
|
|
248 |
|
249 |
return custom_forward
|
250 |
|
251 |
+
(x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
|
252 |
create_custom_forward(block),
|
253 |
x,
|
254 |
past_key_value,
|
|
|
256 |
attention_mask,
|
257 |
self.is_causal,
|
258 |
)
|
|
|
|
|
259 |
else:
|
260 |
(x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
|
261 |
+
if presents is not None:
|
262 |
+
presents += (present,)
|
263 |
+
if output_attentions:
|
264 |
+
assert all_self_attns is not None
|
265 |
+
all_self_attns = all_self_attns + (attn_weights,)
|
266 |
|
267 |
|
268 |
x = self.norm_f(x)
|