Update modeling_baichuan.py for TorchScript mode with past_kv
#2 - opened by changwangss
- modeling_baichuan.py +4 -2
modeling_baichuan.py
CHANGED
@@ -285,7 +285,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
285 |
use_cache: Optional[bool] = False,
|
286 |
output_attentions: Optional[bool] = False,
|
287 |
output_hidden_states: Optional[bool] = False,
|
288 |
-
return_dict: Optional[bool] = True,
|
289 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
290 |
|
291 |
if input_ids is not None and inputs_embeds is not None:
|
@@ -297,6 +297,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
297 |
else:
|
298 |
raise ValueError("You need to provide input_ids or inputs_embeds")
|
299 |
|
|
|
|
|
300 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
301 |
|
302 |
seq_length_with_past = seq_length
|
@@ -437,7 +439,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|
437 |
use_cache: Optional[bool] = None,
|
438 |
output_attentions: Optional[bool] = False,
|
439 |
output_hidden_states: Optional[bool] = False,
|
440 |
-
return_dict: Optional[bool] = True,
|
441 |
**kwargs
|
442 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
443 |
|
|
|
285 |
use_cache: Optional[bool] = False,
|
286 |
output_attentions: Optional[bool] = False,
|
287 |
output_hidden_states: Optional[bool] = False,
|
288 |
+
return_dict: Optional[bool] = None,
|
289 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
290 |
|
291 |
if input_ids is not None and inputs_embeds is not None:
|
|
|
297 |
else:
|
298 |
raise ValueError("You need to provide input_ids or inputs_embeds")
|
299 |
|
300 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
301 |
+
|
302 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
303 |
|
304 |
seq_length_with_past = seq_length
|
|
|
439 |
use_cache: Optional[bool] = None,
|
440 |
output_attentions: Optional[bool] = False,
|
441 |
output_hidden_states: Optional[bool] = False,
|
442 |
+
return_dict: Optional[bool] = None,
|
443 |
**kwargs
|
444 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
445 |
|