Only the logits for the decoder_input_ids are returned, not for the actual input_features
In the sample code from the model card, only the logits for the language token are returned, not the logits for the actual audio. I cannot simply call the generate function, because I need the logits to compute word-level timestamps and to rescore with a language model. Is there a way to obtain the logits?
>>> # Generate logits
>>> logits = model(input_features, decoder_input_ids=torch.tensor([[50258]])).logits
>>> # take argmax and decode
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = processor.batch_decode(predicted_ids)
['<|en|>']
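A forward pass only returns logits for the positions in `decoder_input_ids`, so feeding just the start token yields a single step (hence only `'<|en|>'` above). One way to get logits for every token is a teacher-forced forward pass over the full decoded sequence. The sketch below is an illustrative assumption, not code from the model card: the tiny checkpoint and dummy dataset are placeholders chosen for speed, and the pure-Python `token_logprobs` helper is hypothetical.

```python
import math


def token_logprobs(logit_rows, token_ids):
    """Log-probability of each chosen token, given raw logit rows (lists of floats)."""
    out = []
    for row, tok in zip(logit_rows, token_ids):
        m = max(row)  # subtract the max for a numerically stable softmax
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[tok] - log_z)
    return out


if __name__ == "__main__":
    # Heavy part: needs torch/transformers/datasets and downloads a checkpoint.
    import torch
    from datasets import load_dataset
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    sample = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )[0]["audio"]
    input_features = processor(
        sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
    ).input_features

    # Decode normally first, then re-run one teacher-forced forward pass over the
    # full sequence: logits come back as (batch, seq_len, vocab), one row per position.
    predicted_ids = model.generate(input_features)
    with torch.no_grad():
        logits = model(input_features, decoder_input_ids=predicted_ids).logits

    # The row at position i is the distribution over the token at position i + 1,
    # so shift by one when pairing rows with tokens.
    rows = logits[0].tolist()
    print(token_logprobs(rows[:-1], predicted_ids[0].tolist()[1:]))
```

The resulting per-token log-probabilities are exactly what LM rescoring or timestamp heuristics need, at the cost of one extra forward pass.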
You have to run model.generate(...) to get more than the first token
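`generate` can also hand back the per-step logits directly via `output_scores=True` together with `return_dict_in_generate=True`. A minimal sketch, with the caveat that the tiny checkpoint and dummy dataset are placeholders not taken from this thread, and the `greedy_ids` helper is a hypothetical convenience:

```python
def greedy_ids(score_rows):
    """Argmax token index for each per-step score row (plain lists of floats)."""
    return [max(range(len(row)), key=row.__getitem__) for row in score_rows]


if __name__ == "__main__":
    # Heavy part: downloads a checkpoint; any Whisper size works the same way.
    from datasets import load_dataset
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    sample = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )[0]["audio"]
    input_features = processor(
        sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
    ).input_features

    # out.scores is a tuple with one (batch, vocab) score tensor per generated token.
    out = model.generate(input_features, output_scores=True, return_dict_in_generate=True)
    ids = greedy_ids([step[0].tolist() for step in out.scores])
    print(processor.batch_decode(out.sequences, skip_special_tokens=True))
```

With greedy search the argmax of each score row matches the generated token, so the same tensors can feed timestamp or confidence computations without a second pass.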
In this case, I don't understand how the evaluation was performed. Here the logits are extracted and the ids are decoded. When I run this locally, I indeed only get an empty transcription after normalization, so I'm wondering how it was evaluated.
>>> import torch
>>> from datasets import load_dataset
>>> from jiwer import wer
>>> from transformers import WhisperForConditionalGeneration, WhisperProcessor

>>> librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large").to("cuda")
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-large")

>>> def map_to_pred(batch):
...     input_features = processor(batch["audio"]["array"], return_tensors="pt").input_features
...     with torch.no_grad():
...         logits = model(input_features.to("cuda")).logits
...     predicted_ids = torch.argmax(logits, dim=-1)
...     transcription = processor.batch_decode(predicted_ids, normalize=True)
...     batch["text"] = processor.tokenizer._normalize(batch["text"])
...     batch["transcription"] = transcription
...     return batch

>>> result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
>>> print("WER:", wer(result["text"], result["transcription"]))
0.030003583080317572
Hey! You are correct, the snippet is wrong; we indeed used `generate`! Will fix the evaluation code. Thanks for the catch!
For reference, a corrected code snippet exists here (just swap the base model for large): https://github.com/openai/whisper/blob/a40c75e35cd62b7779774e636b3d081d9cbff82f/README.md#use-in--transformers
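For convenience, here is a hedged sketch of what the corrected evaluation looks like with `generate`. The checkpoint, the single-sample dummy dataset, and the hand-rolled `word_error_rate` helper are illustrative stand-ins, not the exact code behind the reported number (which used `jiwer` on the full test set):

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by reference length (simple WER)."""
    r, h = reference.split(), hypothesis.split()
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            cost = 0 if r[i - 1] == h[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[len(r)][len(h)] / max(len(r), 1)


if __name__ == "__main__":
    # Heavy part: downloads a checkpoint and a dummy audio sample.
    import torch
    from datasets import load_dataset
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    sample = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )[0]
    input_features = processor(
        sample["audio"]["array"], sampling_rate=16_000, return_tensors="pt"
    ).input_features

    with torch.no_grad():
        predicted_ids = model.generate(input_features)  # generate, not a single forward pass
    transcription = processor.batch_decode(
        predicted_ids, skip_special_tokens=True, normalize=True
    )[0]
    reference = processor.tokenizer._normalize(sample["text"])
    print("WER:", word_error_rate(reference, transcription))
```

The key change from the model-card snippet is that `model.generate(...)` runs the full autoregressive decoding loop, whereas a bare forward pass only scores the decoder tokens it was given.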