Commit
•
53d0b79
1
Parent(s):
e660e78
Update model_class.py (#5)
Browse files- Update model_class.py (c0a6c8d4ea3905485ea20f4f16f6a87c09538893)
Co-authored-by: weilai <[email protected]>
- model_class.py +2 -0
model_class.py
CHANGED
@@ -26,6 +26,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
|
|
26 |
output_hidden_states: Optional[bool] = None,
|
27 |
return_dict: Optional[bool] = None,
|
28 |
forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
|
|
|
29 |
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
30 |
return super().forward(
|
31 |
input_features=input_features,
|
@@ -43,6 +44,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
|
|
43 |
output_attentions=output_attentions,
|
44 |
output_hidden_states=output_hidden_states,
|
45 |
return_dict=return_dict,
|
|
|
46 |
)
|
47 |
|
48 |
# copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
|
|
|
26 |
output_hidden_states: Optional[bool] = None,
|
27 |
return_dict: Optional[bool] = None,
|
28 |
forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
|
29 |
+
decoder_position_ids: Optional[torch.LongTensor] = None,
|
30 |
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
31 |
return super().forward(
|
32 |
input_features=input_features,
|
|
|
44 |
output_attentions=output_attentions,
|
45 |
output_hidden_states=output_hidden_states,
|
46 |
return_dict=return_dict,
|
47 |
+
decoder_position_ids=decoder_position_ids,
|
48 |
)
|
49 |
|
50 |
# copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
|