sanchit-gandhi HF staff commited on
Commit
eb2a3e3
1 Parent(s): f3b5e97

Force <|startoftranscript|>

Browse files

Updates the `forced_decoder_ids` to force the `<|startoftranscript|>` token at position 1. This is to match the official Whisper implementation, which always predicts `<|startoftranscript|>` at position 1:
```python
#!pip install git+https://github.com/openai/whisper.git

import whisper
from datasets import load_dataset

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

model = whisper.load_model("tiny.en").to(device)

tokenizer = whisper.tokenizer.get_tokenizer(False, task="transcribe", language="en")
tokenizer = tokenizer.tokenizer

librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

def to_pad_to_mel(array):
"""Static function which:
1. Pads/trims a list of audio arrays to a max length of 30s
2. Computes log-mel filter coefficients from padded/trimmed audio sequences
Inputs:
array: list of audio arrays
Returns:
input_ids: torch.tensor of log-mel filter bank coefficients
"""
padded_input = whisper.pad_or_trim(np.asarray(array, dtype=np.float32))
input_ids = whisper.log_mel_spectrogram(padded_input)
return input_ids

audio_array = librispeech[0]["audio"]["array"]
log_mel = to_pad_to_mel(audio_array).unsqueeze(0)

tokens = model.generate(log_mel.to(device))[0]
transcript = tokenizer.decode(tokens, skip_special_tokens=False)
print(transcript)
```
**Print Output:**
```
<|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle classes, and we are glad to
```

Files changed (1) hide show
  1. config.json +4 -0
config.json CHANGED
@@ -26,6 +26,10 @@
26
  "forced_decoder_ids": [
27
  [
28
  1,
 
 
 
 
29
  50362
30
  ]
31
  ],
 
26
  "forced_decoder_ids": [
27
  [
28
  1,
29
+ 50257
30
+ ]
31
+ [
32
+ 2,
33
  50362
34
  ]
35
  ],