Fix wrong tuple count issue after reapply

#6
Files changed (1) hide show
  1. modeling_mpt.py +6 -8
modeling_mpt.py CHANGED
@@ -248,7 +248,7 @@ class MPTModel(MPTPreTrainedModel):
248
 
249
  return custom_forward
250
 
251
- (x, past_key_value) = torch.utils.checkpoint.checkpoint(
252
  create_custom_forward(block),
253
  x,
254
  past_key_value,
@@ -256,15 +256,13 @@ class MPTModel(MPTPreTrainedModel):
256
  attention_mask,
257
  self.is_causal,
258
  )
259
- if past_key_values is not None:
260
- past_key_values[b_idx] = past_key_value
261
  else:
262
  (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
263
- if presents is not None:
264
- presents += (present,)
265
- if output_attentions:
266
- assert all_self_attns is not None
267
- all_self_attns = all_self_attns + (attn_weights,)
268
 
269
 
270
  x = self.norm_f(x)
 
248
 
249
  return custom_forward
250
 
251
+ (x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
252
  create_custom_forward(block),
253
  x,
254
  past_key_value,
 
256
  attention_mask,
257
  self.is_causal,
258
  )
 
 
259
  else:
260
  (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
261
+ if presents is not None:
262
+ presents += (present,)
263
+ if output_attentions:
264
+ assert all_self_attns is not None
265
+ all_self_attns = all_self_attns + (attn_weights,)
266
 
267
 
268
  x = self.norm_f(x)