Handle model parallelism

#4
by sgugger - opened
Files changed (1) hide show
  1. modeling_codet5p.py +1 -0
modeling_codet5p.py CHANGED
@@ -927,6 +927,7 @@ class CodeT5pEncoderDecoderModel(PreTrainedModel):
927
  loss = None
928
  if labels is not None:
929
  # warnings.warn(DEPRECATION_WARNING, FutureWarning)
 
930
  logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
931
  loss_fct = CrossEntropyLoss()
932
  loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
 
927
  loss = None
928
  if labels is not None:
929
  # warnings.warn(DEPRECATION_WARNING, FutureWarning)
930
+ labels = labels.to(logits.device)
931
  logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
932
  loss_fct = CrossEntropyLoss()
933
  loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))