Handle model parallelism
#4
by
sgugger
- opened
- modeling_codet5p.py +1 -0
modeling_codet5p.py
CHANGED
@@ -927,6 +927,7 @@ class CodeT5pEncoderDecoderModel(PreTrainedModel):
|
|
927 |
loss = None
|
928 |
if labels is not None:
|
929 |
# warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
|
|
930 |
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
|
931 |
loss_fct = CrossEntropyLoss()
|
932 |
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
|
|
|
927 |
loss = None
|
928 |
if labels is not None:
|
929 |
# warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
930 |
+
labels = labels.to(logits.device)
|
931 |
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
|
932 |
loss_fct = CrossEntropyLoss()
|
933 |
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
|