changwangss commited on
Commit
14d5b0e
1 Parent(s): 0ef0739

update modeling_baichuan.py for torchscript mode with past_kv

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +4 -2
modeling_baichuan.py CHANGED
@@ -285,7 +285,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
285
  use_cache: Optional[bool] = False,
286
  output_attentions: Optional[bool] = False,
287
  output_hidden_states: Optional[bool] = False,
288
- return_dict: Optional[bool] = True,
289
  ) -> Union[Tuple, BaseModelOutputWithPast]:
290
 
291
  if input_ids is not None and inputs_embeds is not None:
@@ -297,6 +297,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
297
  else:
298
  raise ValueError("You need to provide input_ids or inputs_embeds")
299
 
 
 
300
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
301
 
302
  seq_length_with_past = seq_length
@@ -437,7 +439,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
437
  use_cache: Optional[bool] = None,
438
  output_attentions: Optional[bool] = False,
439
  output_hidden_states: Optional[bool] = False,
440
- return_dict: Optional[bool] = True,
441
  **kwargs
442
  ) -> Union[Tuple, CausalLMOutputWithPast]:
443
 
 
285
  use_cache: Optional[bool] = False,
286
  output_attentions: Optional[bool] = False,
287
  output_hidden_states: Optional[bool] = False,
288
+ return_dict: Optional[bool] = None,
289
  ) -> Union[Tuple, BaseModelOutputWithPast]:
290
 
291
  if input_ids is not None and inputs_embeds is not None:
 
297
  else:
298
  raise ValueError("You need to provide input_ids or inputs_embeds")
299
 
300
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
301
+
302
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
303
 
304
  seq_length_with_past = seq_length
 
439
  use_cache: Optional[bool] = None,
440
  output_attentions: Optional[bool] = False,
441
  output_hidden_states: Optional[bool] = False,
442
+ return_dict: Optional[bool] = None,
443
  **kwargs
444
  ) -> Union[Tuple, CausalLMOutputWithPast]:
445