File size: 24,284 Bytes
b2d7654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math

import jax
import jax.numpy as jnp
import numpy as np
import requests
from flax import jax_utils
from flax.core.frozen_dict import freeze
from flax.training.common_utils import shard
from jax.sharding import PartitionSpec as P
from transformers import WhisperProcessor, is_tokenizers_available, WhisperFeatureExtractor, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, WhisperTokenizer
from transformers.pipelines.audio_utils import ffmpeg_read
from transformers.utils import logging

from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
from .partitioner import PjitPartitioner
from .train_state import InferenceState


# Module-level logger obtained through the `transformers` logging utilities.
logger = logging.get_logger(name=__name__)

# Logical-to-mesh axis rules for pure data parallelism (DP): only the logical "batch" axis is mapped onto the "data"
# mesh axis; every other logical axis maps to `None`, i.e. it is not partitioned.
# Each logical axis name appears exactly once (a repeated rule for the same name would never be consulted).
logical_axis_rules_dp = (
    ("batch", "data"),
    ("mlp", None),
    ("heads", None),
    ("vocab", None),
    ("embed", None),
    ("joined_kv", None),
    ("kv", None),
    ("length", None),
    ("num_mel", None),
    ("channels", None),
)


class FlaxWhisperPipline:
    """
    Batched, chunked speech-recognition pipeline built on `FlaxWhisperForConditionalGeneration`.

    By default generation runs data-parallel with `jax.pmap` over all local devices; call `shard_params` to switch
    to `pjit`-based partitioning instead.
    """

    def __init__(
        self,
        checkpoint="openai/whisper-large-v2",
        dtype=jnp.float32,
        batch_size=None,
        max_length=None,
    ):
        """
        Args:
            checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"`):
                The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face Hub
                with Flax weights.
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or TPUs.
                If specified all the computation will be performed with the given `dtype`. **Note that this only
                specifies the dtype of the computation and does not influence the dtype of model parameters.**
            batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
                The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
                a batch size in the `__init__` method will be superseded by any batch size passed to the `__call__` method.
            max_length (`int`, *optional*):
                The maximum numbers of tokens to generate. Defaults to `model.generation_config.max_length`.
        """
        self.checkpoint = checkpoint
        self.dtype = dtype

        self.processor = WhisperProcessor.from_pretrained(self.checkpoint)
        self.feature_extractor = self.processor.feature_extractor
        # potentially load fast tokenizer if available
        tokenizer_cls = WhisperTokenizerFast if is_tokenizers_available() else WhisperTokenizer
        self.tokenizer = tokenizer_cls.from_pretrained(checkpoint)

        # with `_do_init=False`, `from_pretrained` hands back the module and its parameters separately
        self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained(
            self.checkpoint,
            _do_init=False,
            dtype=self.dtype,
        )

        self.max_length = max_length if max_length is not None else self.model.generation_config.max_length
        # the pmap path needs at least one example per local device
        self.min_batch_size = jax.local_device_count()
        self.batch_size = batch_size if batch_size is not None else self.min_batch_size

        # use pmap for DP by default - this is compatible on a Colab TPU v2
        self.params = jax_utils.replicate(self.params)
        # params and input features are mapped over their leading (device) axis, the forced decoder ids are broadcast
        # to every device, and `return_timestamps` (argument 3) is a static Python value
        self.p_generate = jax.pmap(
            self._make_generate_fn(),
            "input_features",
            in_axes=(0, 0, None),
            out_axes=0,
            static_broadcasted_argnums=(3,),
        )
        self.is_sharded = False

    def _make_generate_fn(self):
        """Return the (not yet compiled) generation function shared by the pmap and pjit code paths."""

        def generate(params, input_features, forced_decoder_ids, return_timestamps):
            output_ids = self.model.pipeline_generate(
                input_features,
                params=params,
                forced_decoder_ids=forced_decoder_ids,
                return_timestamps=return_timestamps,
                max_length=self.max_length,
            )
            return output_ids

        return generate

    def shard_params(self, num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp):
        """
        Re-partition the parameters with `pjit` and switch generation to the pjit code path.

        The parameters are passed through `self.model.to_bf16` while being sharded.

        Args:
            num_mp_partitions (`int`, *optional*, defaults to 1):
                Number of model-parallel partitions, forwarded to `PjitPartitioner(num_partitions=...)`.
            logical_axis_rules (`tuple`, *optional*, defaults to `logical_axis_rules_dp`):
                Logical-to-mesh axis rules forwarded to the partitioner.
        """

        def init_fn():
            # dummy inputs: only their shapes/dtypes matter, since this function is only run under `jax.eval_shape`
            input_shape = (1, self.model.config.num_mel_bins, 2 * self.model.config.max_source_positions)

            input_features = jnp.zeros(input_shape, dtype="f4")
            input_features = input_features.at[(..., -1)].set(self.model.config.eos_token_id)

            decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

            rng = jax.random.PRNGKey(0)
            init_params = self.model.module.init(
                rng,
                input_features=input_features,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_position_ids=decoder_position_ids,
                return_dict=False,
            )
            return init_params

        # Axis names metadata (abstract evaluation only - no real parameters are allocated)
        param_axes = jax.eval_shape(init_fn)["params_axes"]

        # Create InferenceState, since the partitioner expects it
        state = InferenceState(
            step=jnp.array(0),
            params=freeze(self.model.params_shape_tree),
            params_axes=freeze(param_axes),
            flax_mutables=None,
            flax_mutables_axes=param_axes,
        )

        partitioner = PjitPartitioner(num_partitions=num_mp_partitions, logical_axis_rules=logical_axis_rules)

        mesh_axes = partitioner.get_mesh_axes(state)
        params_spec = mesh_axes.params

        p_shard_params = partitioner.partition(self.model.to_bf16, (params_spec,), params_spec)

        # This will auto-magically run in mesh context.
        # NOTE(review): `unreplicate` assumes `self.params` is still in the pmap-replicated layout built in
        # `__init__`; calling this method a second time would not satisfy that assumption.
        self.params = p_shard_params(freeze(jax_utils.unreplicate(self.params)))
        self.is_sharded = True

        # Use pjit for generate only once we've sharded the params
        self.p_generate = partitioner.partition(
            self._make_generate_fn(),
            in_axis_resources=(params_spec, P("data"), None),
            out_axis_resources=P("data"),
            static_argnums=(3,),
        )

    def generate(self, input_features, language=None, task=None, return_timestamps=False):
        """
        Generate token ids for a batch of input features.

        Args:
            input_features (`np.ndarray`):
                Batch of input features. On the pmap path the leading (batch) dimension is split across local devices
                with `shard`, so it must be divisible by `jax.local_device_count()`.
            language (`str`, *optional*): See `get_forced_decoder_ids`.
            task (`str`, *optional*): See `get_forced_decoder_ids`.
            return_timestamps (`bool`, *optional*, defaults to `False`): Whether timestamp prediction is allowed.

        Returns:
            The generated token ids, one row per input example.
        """
        forced_decoder_ids = self.get_forced_decoder_ids(
            language=language, task=task, return_timestamps=return_timestamps
        )
        if not self.is_sharded:
            # if we're using pmap we need to manually replicate the input data across devices and gather the output tokens
            output_ids = self.p_generate(
                freeze(self.params), shard(input_features), forced_decoder_ids, return_timestamps
            ).sequences
            # merge the (device, per-device batch) leading axes back into a single batch axis
            output_ids = jax.device_get(output_ids.reshape(-1, self.max_length))
        else:
            # pjit handles replication / gathering for us auto-magically
            output_ids = self.p_generate(
                freeze(self.params), input_features, forced_decoder_ids, return_timestamps
            ).sequences
        return output_ids

    def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
        """
        Build the `(position, token_id)` pairs that are forced at the start of the decoder output.

        For multilingual checkpoints, the language token is forced at position 1 (only when `language` is given) and
        the task token at position 2 (`task`, or `"transcribe"` by default). When `return_timestamps` is falsy, the
        `no_timestamps_token_id` token is forced right after the last forced token, or at position 1 when nothing else
        is forced (e.g. for checkpoints that are not multilingual).

        Args:
            generation_config (*optional*):
                Generation config to read token ids from. Defaults to `self.model.generation_config`.
            task (`str`, *optional*):
                Key of `generation_config.task_to_id`. Defaults to `"transcribe"` for multilingual checkpoints.
            language (`str`, *optional*):
                Either a generation-config language token (e.g. `"<|en|>"`), an ISO 639-1 code (e.g. `"en"`) or a
                language name (e.g. `"english"`); matched case-insensitively. Only used for multilingual checkpoints.
            return_timestamps (`bool`, *optional*, defaults to `False`):
                Whether timestamp prediction should be allowed.

        Returns:
            `List[Tuple[int, int]]`: The forced `(position, token_id)` pairs.

        Raises:
            ValueError: If `language` is not a recognised language.
        """
        if generation_config is None:
            generation_config = self.model.generation_config

        is_multilingual = getattr(generation_config, "is_multilingual", None)

        forced_decoder_ids = []

        if is_multilingual:
            if language is not None:
                language = language.lower()
                if language in generation_config.lang_to_id:
                    language_token = language
                elif language in TO_LANGUAGE_CODE.values():
                    language_token = f"<|{language}|>"
                elif language in TO_LANGUAGE_CODE:
                    language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
                else:
                    # tailor the error message to the format the caller appears to have used
                    if len(language) == 2:
                        # ISO 639-1 language code
                        acceptable_languages = list(TO_LANGUAGE_CODE.values())
                    elif "<" in language or "|" in language or ">" in language:
                        # generation config language code
                        acceptable_languages = list(generation_config.lang_to_id.keys())
                    else:
                        # language passed as a string
                        acceptable_languages = list(TO_LANGUAGE_CODE.keys())
                    raise ValueError(
                        f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}."
                    )
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))

            task_name = task if task is not None else "transcribe"
            forced_decoder_ids.append((2, generation_config.task_to_id[task_name]))

        if not return_timestamps:
            no_timestamps_token_id = getattr(generation_config, "no_timestamps_token_id", None)
            # compare against the last forced *token id* (element 1), not its position (element 0)
            already_forced = bool(forced_decoder_ids) and forced_decoder_ids[-1][1] == no_timestamps_token_id
            if no_timestamps_token_id is not None and not already_forced:
                # force it right after the last forced token, or at position 1 if nothing else has been forced
                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                forced_decoder_ids.append((idx, no_timestamps_token_id))

        return forced_decoder_ids

    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
        """
        Split `inputs` into overlapping chunks and yield them as feature-extracted batches of at most `batch_size`.

        Consecutive chunks start `chunk_len - stride_left - stride_right` samples apart. Chunking stops at the first
        chunk that reaches the end of the audio, so no trailing chunk made only of already-covered samples is emitted.

        Args:
            inputs (`np.ndarray`): 1-D audio array.
            chunk_len (`int`): Chunk length in samples.
            stride_left (`int`): Left context length in samples.
            stride_right (`int`): Right context length in samples.
            batch_size (`int`): Maximum number of chunks per yielded batch.

        Yields:
            `dict`: The feature extractor outputs for the batch, plus a `"stride"` list holding one
            `(chunk_length, stride_left, stride_right)` tuple (in samples) per chunk. The first chunk has no left
            stride and the last chunk has no right stride.
        """
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right

        all_chunk_start_idx = np.arange(0, inputs_len, step)
        all_chunk_end_idx = all_chunk_start_idx + chunk_len
        # A chunk is the last one once it reaches the end of the audio. With a right stride it must go strictly past
        # the end: a chunk ending exactly at the end would otherwise only see its final samples as right context.
        if stride_right > 0:
            all_is_last = all_chunk_end_idx > inputs_len
        else:
            all_is_last = all_chunk_end_idx >= inputs_len
        if all_is_last.any():
            # end positions increase monotonically, so everything after the first "last" chunk is redundant
            num_samples = int(np.argmax(all_is_last)) + 1
        else:
            num_samples = len(all_chunk_start_idx)
        all_chunk_start_idx = all_chunk_start_idx[:num_samples]
        all_is_last = all_is_last[:num_samples]

        num_batches = math.ceil(num_samples / batch_size)
        # `array_split` yields `num_batches` nearly equal groups, each holding at most `batch_size` indices
        batch_idx = np.array_split(np.arange(num_samples), num_batches)

        for idx in batch_idx:
            chunk_start_idx = all_chunk_start_idx[idx]
            chunk_end_idx = chunk_start_idx + chunk_len

            chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
            processed = self.feature_extractor(
                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )

            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            _stride_right = np.where(all_is_last[idx], 0, stride_right)

            # use each chunk's real length: the final chunk can be shorter than `chunk_len`
            strides = [
                (chunk.shape[0], _stride_l, _stride_r)
                for chunk, _stride_l, _stride_r in zip(chunks, _stride_left, _stride_right)
            ]

            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None):
        """
        Load and validate an audio input, then yield batches of model inputs.

        Args:
            inputs (`np.ndarray`, `bytes`, `str` or `dict`):
                The audio input; see `__call__` for the accepted formats.
            chunk_length_s (`float`, *optional*, defaults to 30.0):
                Chunk length in seconds. A falsy value disables chunking and the whole input is yielded as one example.
            stride_length_s (`float` or pair of `float`, *optional*):
                Left/right stride in seconds. Defaults to `chunk_length_s / 6` on both sides.
            batch_size (`int`):
                Maximum number of chunks per yielded batch when chunking.

        Yields:
            `dict`: Feature extractor outputs, with a `"stride"` entry holding one `(length, left, right)` tuple (in
            samples) per example whenever stride information is available.

        Raises:
            ValueError: For malformed dict inputs, non-array or multi-channel audio, or strides that are too large.
            ImportError: If resampling is required and `librosa` is not installed.
        """
        if isinstance(inputs, np.ndarray):
            logger.warning(
                "Numpy array passed as input - no sampling rate checks will be performed. "
                "It is strongly recommended to pass the input as a dictionary with an 'array' key "
                "containing the numpy array representing the audio, and a 'sampling_rate' key "
                "containing the sampling rate associated with the audio array. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if isinstance(inputs, str):
            if inputs.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            # decode the raw audio file with ffmpeg at the feature extractor's sampling rate
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

        stride = None
        if isinstance(inputs, dict):
            stride = inputs.get("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and "array" in inputs):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key "
                    "containing the numpy array representing the audio, and a 'sampling_rate' key "
                    "containing the sampling rate associated with the audio array."
                )

            in_sampling_rate = inputs.get("sampling_rate")
            inputs = inputs.get("array", None)

            if in_sampling_rate != self.feature_extractor.sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err

                inputs = librosa.resample(
                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
                )
                # the user-supplied stride is expressed in input samples; rescale it to the resampled audio
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if stride is not None:
            if stride[0] + stride[1] > inputs.shape[0]:
                raise ValueError("Stride is too large for input")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
            stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
            stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)

            # the chunk must be strictly longer than its strides, otherwise consecutive chunks would never advance
            if chunk_len <= stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

            yield from self.chunk_iter_with_batch(
                inputs,
                chunk_len,
                stride_left,
                stride_right,
                batch_size,
            )
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            if stride is not None:
                # one stride tuple per example (the batch holds a single example here), so that `postprocess` can
                # zip it against the per-example tokens
                processed["stride"] = [stride]
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
        """
        Merge the per-batch outputs of `forward` into the final transcription.

        Args:
            model_outputs (`List[Dict]`):
                Outputs of `forward`. Each dict maps `"tokens"` (and optionally `"stride"`) to per-example sequences of
                equal length.
            return_timestamps (*optional*): Forwarded to the tokenizer's `_decode_asr`.
            return_language (*optional*): Forwarded to the tokenizer's `_decode_asr`.

        Returns:
            `Dict`: `{"text": ...}` plus any additional entries returned by `_decode_asr`.
        """
        # unpack the outputs from list(dict(list)) to list(dict)
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]

        # duration covered by one encoder position (in the feature extractor's `chunk_length` units, presumably seconds)
        time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
        # Send the chunking back to seconds, it's easier to handle in whisper
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                # Go back in seconds
                chunk_len /= sampling_rate
                stride_left /= sampling_rate
                stride_right /= sampling_rate
                output["stride"] = chunk_len, stride_left, stride_right

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}

    def forward(self, model_inputs, batch_size=None, language=None, task=None, return_timestamps=False):
        """
        Run generation on one preprocessed batch.

        The input features are zero-padded up to `batch_size`, so that every call uses the same batch size (which
        `__call__` checks is divisible by the number of local devices), and the padded rows are dropped again from
        the predictions.

        Args:
            model_inputs (`dict`):
                One item yielded by `preprocess_batch`. Its `"input_features"` and `"stride"` entries are popped, so the
                dict is mutated.
            batch_size (`int`): Batch size to pad the input features to.
            language (`str`, *optional*): See `get_forced_decoder_ids`.
            task (`str`, *optional*): See `get_forced_decoder_ids`.
            return_timestamps (`bool`, *optional*, defaults to `False`): Whether timestamp prediction is allowed.

        Returns:
            `dict`: `{"tokens": ...}` with shape `(input_batch_size, 1, length)`, plus `"stride"` when present.
        """
        # We need to keep track of some additional input arguments for post-processing so need to forward these on after running generation
        input_features = model_inputs.pop("input_features")
        input_batch_size = input_features.shape[0]

        if input_batch_size != batch_size:
            padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype)
            input_features = np.concatenate([input_features, padding])

        pred_ids = self.generate(input_features, language=language, task=task, return_timestamps=return_timestamps)[
            :input_batch_size
        ]

        # tokenizer's decode method expects an extra dim - we insert it here for convenience
        out = {"tokens": pred_ids[:, None, :]}

        stride = model_inputs.pop("stride", None)
        if stride is not None:
            out["stride"] = stride

        return out

    def __call__(
        self,
        inputs,
        chunk_length_s=30.0,
        stride_length_s=None,
        batch_size=None,
        language=None,
        task=None,
        return_timestamps=None,
        generate_kwargs=None,
    ):
        """
        Transcribe an audio input sequence to a text transcription, optionally with timestamps.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either:
                    - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
                      to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                      Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling
                      rate check will be done.
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array":
                      np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to
                      ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in
                      decoding (but used at inference to provide more context to the model). In general, this additional
                      stride argument is not required.
            chunk_length_s (`float`, *optional*, defaults to 30.0):
                The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled. By default, the chunk
                length is set 30.0s, equal to Whisper's context window.
            stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
                The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
                the model to *see* more context and infer letters better than without this context but the pipeline
                discards the stride bits at the end to make the final reconstitution as perfect as possible. The chunk
                length must be strictly greater than the sum of the left and right strides.

                <Tip>

                For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking
                blog post](https://huggingface.co/blog/asr-chunking).

                </Tip>
            batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
                The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
                a batch size in the `__call__` method will supersede any batch size passed to the `__init__`. Must be a
                positive multiple of the number of local JAX devices.
            task (`str`, *optional*):
                Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
            language (`str`, *optional*):
                Language token to use for generation, can be either in the form of `"<|en|>"`, `"en"` or `"english"`.
                Defaults to `None`, meaning the language is automatically inferred from the audio input.
            return_timestamps (*optional*, `bool`):
                Whether to return timestamps in the prediction. Defaults to `None`; any falsy value disables timestamp
                prediction. If set to `True`, the pipeline will return two keys in the output dictionary: `"text"`
                containing the text transcription, and `"chunks"` containing the transcription segments chunked by their
                utterance-level timestamps.
            generate_kwargs (`dict`, *optional*):
                Currently unused; accepted for API compatibility.

        Return:
            `Dict`: A dictionary with the following keys:
                - **text** (`str`) -- The recognised text.
                - **chunks** (*optional*, `List[Dict]`) -- When using `return_timestamps`, the `chunks` will become a list
                    containing all the various text chunks identified by the model, *e.g.* `[{"text": "hi ",
                    "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`. The original full text can
                    roughly be recovered by doing `"".join(chunk["text"] for chunk in output["chunks"])`.

        Raises:
            ValueError: If `batch_size` is not a positive multiple of the number of local JAX devices, or if the input
                is invalid (see `preprocess_batch`).
        """
        batch_size = batch_size if batch_size is not None else self.batch_size
        # a non-positive batch size would pass the modulo check below and fail later with a division by zero
        if batch_size <= 0:
            raise ValueError(f"Batch size must be a positive integer, but got batch size {batch_size}.")
        if batch_size % self.min_batch_size != 0:
            raise ValueError(
                f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}."
            )

        dataloader = self.preprocess_batch(
            inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size
        )
        # iterate over our chunked audio samples
        model_outputs = [
            self.forward(
                batch, batch_size=batch_size, language=language, task=task, return_timestamps=return_timestamps
            )
            for batch in dataloader
        ]
        post_processed = self.postprocess(model_outputs, return_timestamps=return_timestamps)
        return post_processed