Fix mixed precision loading with recent transformers versions

#39
Files changed (1)
  1. modeling_xlm_roberta.py +1 -0
modeling_xlm_roberta.py CHANGED
@@ -404,6 +404,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
     config_class = XLMRobertaFlashConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
+    _supports_param_buffer_assignment = False

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, XLMRobertaEncoder):
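
For context: recent transformers releases can load checkpoints by assigning the state dict tensors directly to parameters and buffers instead of copying into them (`_supports_param_buffer_assignment` defaults to `True`). As the PR title suggests, that fast path interacts badly with mixed-precision loading here: assignment can leave parameters in the checkpoint's storage dtype rather than the dtype the model was instantiated with. Opting out forces copy-based loading, which casts weights into the pre-cast parameters. A minimal sketch of how one might verify the fix; the repo id below is an assumption, substitute the actual model:

```python
import torch
from transformers import AutoModel

# Assumption: repo id hosting this custom modeling_xlm_roberta.py; adjust as needed.
repo_id = "jinaai/xlm-roberta-flash-implementation"

model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,     # the custom XLMRobertaFlash code lives in the repo
    torch_dtype=torch.float16,  # request half-precision weights
)

# With _supports_param_buffer_assignment = False, loading copies checkpoint
# tensors into the fp16 parameters instead of assigning them, so every
# parameter should keep the requested dtype.
assert all(p.dtype == torch.float16 for p in model.parameters())
```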