update modeling_baichuan.py for torchscript mode with past_kv
#5
by
changwangss
- opened
- modeling_baichuan.py +4 -2
modeling_baichuan.py
CHANGED
@@ -365,7 +365,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
365 |
use_cache: Optional[bool] = False,
|
366 |
output_attentions: Optional[bool] = False,
|
367 |
output_hidden_states: Optional[bool] = False,
|
368 |
-
return_dict: Optional[bool] =
|
369 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
370 |
if input_ids is not None and inputs_embeds is not None:
|
371 |
raise ValueError(
|
@@ -378,6 +378,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
378 |
else:
|
379 |
raise ValueError("You need to provide input_ids or inputs_embeds")
|
380 |
|
|
|
|
|
381 |
return_dict = (
|
382 |
return_dict if return_dict is not None else self.config.use_return_dict
|
383 |
)
|
@@ -682,7 +684,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|
682 |
use_cache: Optional[bool] = None,
|
683 |
output_attentions: Optional[bool] = False,
|
684 |
output_hidden_states: Optional[bool] = False,
|
685 |
-
return_dict: Optional[bool] =
|
686 |
**kwargs,
|
687 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
688 |
return_dict = (
|
|
|
365 |
use_cache: Optional[bool] = False,
|
366 |
output_attentions: Optional[bool] = False,
|
367 |
output_hidden_states: Optional[bool] = False,
|
368 |
+
return_dict: Optional[bool] = None,
|
369 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
370 |
if input_ids is not None and inputs_embeds is not None:
|
371 |
raise ValueError(
|
|
|
378 |
else:
|
379 |
raise ValueError("You need to provide input_ids or inputs_embeds")
|
380 |
|
381 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
382 |
+
|
383 |
return_dict = (
|
384 |
return_dict if return_dict is not None else self.config.use_return_dict
|
385 |
)
|
|
|
684 |
use_cache: Optional[bool] = None,
|
685 |
output_attentions: Optional[bool] = False,
|
686 |
output_hidden_states: Optional[bool] = False,
|
687 |
+
return_dict: Optional[bool] = None,
|
688 |
**kwargs,
|
689 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
690 |
return_dict = (
|