Afrinetwork7 committed on
Commit
8da2b37
1 Parent(s): 498188c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -200
app.py CHANGED
@@ -1,215 +1,152 @@
1
- import base64
2
  import logging
3
  import math
 
4
  import tempfile
5
  import time
6
- from typing import Optional, Tuple
7
- import os
8
-
9
- import fastapi
10
  import jax.numpy as jnp
11
  import numpy as np
12
- import yt_dlp as youtube_dl
13
- from jax.experimental.compilation_cache import compilation_cache as cc
14
- from pydantic import BaseModel
15
  from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
16
  from transformers.pipelines.audio_utils import ffmpeg_read
17
-
18
  from whisper_jax import FlaxWhisperPipline
19
 
20
# Root logging configuration plus the logger shared by every function in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whisper-jax-app")

try:
    # Point the JAX compilation cache at a local directory.
    cc.initialize_cache("./jax_cache")
    checkpoint = "openai/whisper-large-v3"

    BATCH_SIZE = 32
    CHUNK_LENGTH_S = 30
    NUM_PROC = 32
    FILE_LIMIT_MB = 10000
    YT_LENGTH_LIMIT_S = 15000  # 15000 s = 4 h 10 min

    pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)

    # Chunking geometry, expressed in samples: each chunk overlaps its neighbours
    # by one stride on either side, so consecutive chunks start `step` samples apart.
    stride_length_s = CHUNK_LENGTH_S / 6
    chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
    stride_left = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
    stride_right = stride_left
    step = chunk_len - stride_left - stride_right

    # Run one dummy batch through the model up front and time it, so the cost of
    # the first forward call is paid at start-up rather than by the first request.
    logger.info("compiling forward call...")
    start = time.time()
    random_inputs = {
        "input_features": np.ones(
            (
                BATCH_SIZE,
                pipeline.model.config.num_mel_bins,
                2 * pipeline.model.config.max_source_positions,
            )
        )
    }
    random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
    compile_time = time.time() - start
    logger.info(f"compiled in {compile_time}s")

except Exception as e:
    # Log the failure and let it propagate: the app cannot serve without the pipeline.
    logger.error(f"Error during initialization: {str(e)}")
    raise
55
-
56
# FastAPI application object that the route decorators below attach to.
app = fastapi.FastAPI()


class TranscriptionRequest(BaseModel):
    """Request body of the ``/transcribe`` endpoint."""

    # Audio payload as a string; the endpoint base64-decodes it before reading.
    audio_file: str
    # Forwarded unchanged as the ``task`` argument of the pipeline forward call.
    task: str = "transcribe"
    # When true, the returned text is formatted as one timestamped line per chunk.
    return_timestamps: bool = False


class TranscriptionResponse(BaseModel):
    """Response body of the ``/transcribe`` endpoint."""

    # Transcribed text (timestamped lines when requested).
    transcription: str
    # Wall-clock time spent in the model forward passes, as measured with time.time().
    runtime: float
66
-
67
@app.post("/transcribe", response_model=TranscriptionResponse)
def transcribe_audio(request: TranscriptionRequest):
    """Transcribe a base64-encoded audio payload.

    Args:
        request: Body carrying the base64 audio string, the task forwarded to the
            pipeline, and whether the text should be formatted with timestamps.

    Returns:
        A ``TranscriptionResponse`` holding the text and the generation runtime.

    Raises:
        fastapi.HTTPException: 400 when the payload is missing, is not valid
            base64, or exceeds ``FILE_LIMIT_MB``; 500 for any other failure.
    """
    import binascii  # local import: only needed to recognise malformed base64

    try:
        logger.info("loading audio file...")
        if not request.audio_file:
            logger.warning("No audio file")
            raise fastapi.HTTPException(status_code=400, detail="No audio file submitted!")

        try:
            audio_bytes = base64.b64decode(request.audio_file)
        except (binascii.Error, ValueError) as err:
            # Malformed input is the client's fault, so answer 400 rather than 500.
            logger.warning("Invalid base64 payload")
            raise fastapi.HTTPException(
                status_code=400, detail=f"Audio file is not valid base64: {err}"
            ) from err

        file_size_mb = len(audio_bytes) / (1024 * 1024)
        if file_size_mb > FILE_LIMIT_MB:
            logger.warning("Max file size exceeded")
            raise fastapi.HTTPException(
                status_code=400,
                detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.",
            )

        inputs = ffmpeg_read(audio_bytes, pipeline.feature_extractor.sampling_rate)
        inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
        logger.info("done loading")
        text, runtime = _tqdm_generate(inputs, task=request.task, return_timestamps=request.return_timestamps)
        return TranscriptionResponse(transcription=text, runtime=runtime)
    except fastapi.HTTPException:
        # The 400 responses raised above must reach the client unchanged; without
        # this clause the generic handler below would rewrap them as 500.
        raise
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}")
        raise fastapi.HTTPException(
            status_code=500, detail=f"An error occurred during transcription: {str(e)}"
        ) from e
92
 
93
@app.post("/transcribe_youtube")
def transcribe_youtube(
    yt_url: str, task: str = "transcribe", return_timestamps: bool = False
) -> Tuple[str, str, float]:
    """Download a YouTube video and transcribe its audio track.

    Args:
        yt_url: URL of the video to fetch.
        task: Forwarded unchanged to the pipeline forward call.
        return_timestamps: Format the text as one timestamped line per chunk.

    Returns:
        A tuple of (HTML embed snippet, transcribed text, generation runtime).

    Raises:
        fastapi.HTTPException: Whatever status the download helper raised
            (400 for unavailable or over-long videos); 500 for any other failure.
    """
    try:
        logger.info("loading youtube file...")
        html_embed_str = _return_yt_html_embed(yt_url)
        with tempfile.TemporaryDirectory() as tmpdirname:
            filepath = os.path.join(tmpdirname, "video.mp4")
            _download_yt_audio(yt_url, filepath)

            with open(filepath, "rb") as f:
                inputs = f.read()

        inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
        inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
        logger.info("done loading...")
        text, runtime = _tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
        return html_embed_str, text, runtime
    except fastapi.HTTPException:
        # Keep the status code chosen by the helper (e.g. 400 for a video that is
        # too long) instead of letting the generic handler turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error in transcribe_youtube: {str(e)}")
        raise fastapi.HTTPException(
            status_code=500, detail=f"An error occurred during YouTube transcription: {str(e)}"
        ) from e
115
-
116
def _tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
    """Run the pipeline over chunked audio and return ``(text, runtime)``.

    ``inputs`` is the dict handed to ``pipeline.preprocess_batch``; its ``"array"``
    entry must expose ``shape[0]``. ``runtime`` covers only the forward passes.
    Any exception is logged and re-raised unchanged.
    """
    try:
        # Number of batches needed to cover every chunk start position.
        chunk_starts = np.arange(0, inputs["array"].shape[0], step)
        num_batches = math.ceil(len(chunk_starts) / BATCH_SIZE)

        dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
        start_time = time.time()
        logger.info("transcribing...")
        # Timestamps are always requested from the model; they are only rendered
        # into the text below when the caller asked for them.
        model_outputs = [
            pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True)
            for batch, _ in zip(dataloader, range(num_batches))
        ]
        runtime = time.time() - start_time
        logger.info("done transcription")

        logger.info("post-processing...")
        post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
        text = post_processed["text"]
        if return_timestamps:
            lines = [
                f"[{_format_timestamp(chunk['timestamp'][0])} -> {_format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
                for chunk in post_processed.get("chunks")
            ]
            text = "\n".join(lines)
        logger.info("done post-processing")
        return text, runtime
    except Exception as e:
        logger.error(f"Error in _tqdm_generate: {str(e)}")
        raise
147
-
148
- def _return_yt_html_embed(yt_url: str) -> str:
149
  try:
150
- video_id = yt_url.split("?v=")[-1]
151
- HTML_str = (
152
- f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
153
- " </center>"
154
- )
155
- return HTML_str
156
- except Exception as e:
157
- logger.error(f"Error in _return_yt_html_embed: {str(e)}")
158
- raise
159
-
160
def _download_yt_audio(yt_url: str, filename: str):
    """Check a YouTube video's length and download it to ``filename``.

    Args:
        yt_url: URL of the video to fetch.
        filename: Output path handed to yt-dlp as its ``outtmpl``.

    Raises:
        fastapi.HTTPException: 400 when the metadata cannot be fetched, the
            duration is unavailable, the video is longer than
            ``YT_LENGTH_LIMIT_S`` seconds, or the download itself fails.
    """
    try:
        info_loader = youtube_dl.YoutubeDL()

        try:
            info = info_loader.extract_info(yt_url, download=False)
        except youtube_dl.utils.DownloadError as err:
            raise fastapi.HTTPException(status_code=400, detail=str(err))

        # The duration can be absent from the metadata; answer 400 instead of
        # failing with a KeyError.
        file_length = info.get("duration_string")
        if not file_length:
            raise fastapi.HTTPException(
                status_code=400, detail="Could not determine the length of the YouTube video."
            )

        # "SS", "MM:SS" or "HH:MM:SS": left-pad with zeros to three fields.
        file_h_m_s = [int(sub_length) for sub_length in file_length.split(":")]
        file_h_m_s = [0] * (3 - len(file_h_m_s)) + file_h_m_s

        file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
        if file_length_s > YT_LENGTH_LIMIT_S:
            yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
            file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
            raise fastapi.HTTPException(
                status_code=400,
                detail=f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.",
            )

        ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            try:
                ydl.download([yt_url])
            except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
                # yt-dlp reports failed downloads as DownloadError; catching only
                # ExtractorError let them escape as unhandled errors.
                raise fastapi.HTTPException(status_code=400, detail=str(err))
    except Exception as e:
        logger.error(f"Error in _download_yt_audio: {str(e)}")
        raise
194
-
195
- def _format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
196
- try:
197
- if seconds is not None:
198
- milliseconds = round(seconds * 1000.0)
199
-
200
- hours = milliseconds // 3_600_000
201
- milliseconds -= hours * 3_600_000
202
-
203
- minutes = milliseconds // 60_000
204
- milliseconds -= minutes * 60_000
205
-
206
- seconds = milliseconds // 1_000
207
- milliseconds -= seconds * 1_000
208
-
209
- hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
210
- return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
211
- else:
212
- return seconds
213
- except Exception as e:
214
- logger.error(f"Error in _format_timestamp: {str(e)}")
215
- raise
 
 
1
  import logging
2
  import math
3
+ import os
4
  import tempfile
5
  import time
6
+ import yt_dlp as youtube_dl
7
+ from fastapi import FastAPI, UploadFile, Form, HTTPException
8
+ from fastapi.responses import HTMLResponse
 
9
  import jax.numpy as jnp
10
  import numpy as np
 
 
 
11
  from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
12
  from transformers.pipelines.audio_utils import ffmpeg_read
 
13
  from whisper_jax import FlaxWhisperPipline
14
 
15
# `cc` is used on the next statement but is not among the file's imports, so the
# module would stop with a NameError on import; bring it into scope here.
from jax.experimental.compilation_cache import compilation_cache as cc

# Point the JAX compilation cache at a local directory.
cc.initialize_cache("./jax_cache")
checkpoint = "openai/whisper-large-v3"

BATCH_SIZE = 32
CHUNK_LENGTH_S = 30
NUM_PROC = 32
FILE_LIMIT_MB = 10000
YT_LENGTH_LIMIT_S = 15000  # 15000 s = 4 h 10 min

app = FastAPI(title="Whisper JAX: The Fastest Whisper API ⚡️")

# Dedicated logger writing "timestamp;level;message" records to a stream handler.
logger = logging.getLogger("whisper-jax-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)

pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)

# Chunking geometry in samples: each chunk overlaps its neighbours by one stride
# on either side, so consecutive chunks start `step` samples apart.
stride_length_s = CHUNK_LENGTH_S / 6
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
step = chunk_len - stride_left - stride_right

# do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
logger.info("compiling forward call...")
start = time.time()
random_inputs = {
    "input_features": np.ones(
        (BATCH_SIZE, pipeline.model.config.num_mel_bins, 2 * pipeline.model.config.max_source_positions)
    )
}
random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
compile_time = time.time() - start
logger.info(f"compiled in {compile_time}s")
51
+
52
@app.post("/transcribe_audio")
async def transcribe_chunked_audio(audio_file: UploadFile, task: str = "transcribe", return_timestamps: bool = False):
    """Transcribe an uploaded audio file.

    Args:
        audio_file: The multipart upload to transcribe.
        task: Forwarded unchanged to the pipeline forward call.
        return_timestamps: Format the text as one timestamped line per chunk.

    Returns:
        A dict with the transcribed ``"text"`` and the generation ``"runtime"``.

    Raises:
        HTTPException: 400 when no audio was submitted or the upload is larger
            than ``FILE_LIMIT_MB``.
    """
    logger.info("loading audio file...")
    if not audio_file:
        logger.warning("No audio file")
        raise HTTPException(status_code=400, detail="No audio file submitted!")

    # Read the uploaded bytes from the request itself. `audio_file.filename` is
    # only the name the client sent, not a path on this machine, so calling
    # os.stat()/open() on it either fails or touches an unrelated local file.
    inputs = await audio_file.read()
    if not inputs:
        logger.warning("No audio file")
        raise HTTPException(status_code=400, detail="No audio file submitted!")

    file_size_mb = len(inputs) / (1024 * 1024)
    if file_size_mb > FILE_LIMIT_MB:
        logger.warning("Max file size exceeded")
        raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")

    inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
    inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
    logger.info("done loading")
    text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
    return {"text": text, "runtime": runtime}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
@app.post("/transcribe_youtube")
async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe", return_timestamps: bool = False):
    """Download a YouTube video and transcribe its audio track.

    ``yt_url`` arrives as a form field; ``task`` and ``return_timestamps`` are
    passed through to ``tqdm_generate``. Returns a dict with the HTML embed
    snippet, the transcribed text and the generation runtime.
    """
    logger.info("loading youtube file...")
    embed_markup = _return_yt_html_embed(yt_url)

    # The download lives in a throw-away directory; only its bytes are kept.
    with tempfile.TemporaryDirectory() as workdir:
        video_path = os.path.join(workdir, "video.mp4")
        download_yt_audio(yt_url, video_path)
        with open(video_path, "rb") as handle:
            raw_bytes = handle.read()

    waveform = ffmpeg_read(raw_bytes, pipeline.feature_extractor.sampling_rate)
    model_inputs = {"array": waveform, "sampling_rate": pipeline.feature_extractor.sampling_rate}
    logger.info("done loading...")

    text, runtime = tqdm_generate(model_inputs, task=task, return_timestamps=return_timestamps)
    return {"html_embed": embed_markup, "text": text, "runtime": runtime}
88
+
89
def _format_timestamp(seconds, always_include_hours=False, decimal_marker="."):
    """Render seconds as ``[HH:]MM:SS<marker>mmm``; ``None`` is returned unchanged."""
    if seconds is None:
        return seconds
    remaining_ms = round(seconds * 1000.0)
    hours, remaining_ms = divmod(remaining_ms, 3_600_000)
    minutes, remaining_ms = divmod(remaining_ms, 60_000)
    whole_seconds, millis = divmod(remaining_ms, 1_000)
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"


def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
    """Run the pipeline over chunked audio and return ``(text, runtime)``.

    Args:
        inputs: Dict handed to ``pipeline.preprocess_batch``.
        task: Forwarded unchanged to ``pipeline.forward``.
        return_timestamps: When true, the text is one
            ``[start -> end] text`` line per chunk instead of the plain transcript.

    Returns:
        A tuple of the text and the wall-clock time spent in the forward passes.
    """
    dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
    model_outputs = []
    start_time = time.time()
    logger.info("transcribing...")
    # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations
    for batch in dataloader:
        model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
    runtime = time.time() - start_time
    logger.info("done transcription")

    logger.info("post-processing...")
    post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
    text = post_processed["text"]
    if return_timestamps:
        # The formatter is defined in this block because no `format_timestamp`
        # exists elsewhere in the file; calling that name raised NameError.
        timestamps = post_processed.get("chunks")
        timestamps = [
            f"[{_format_timestamp(chunk['timestamp'][0])} -> {_format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in timestamps
        ]
        text = "\n".join(str(feature) for feature in timestamps)
    logger.info("done post-processing")
    return text, runtime
117
+
118
+ def _return_yt_html_embed(yt_url):
119
+ video_id = yt_url.split("?v=")[-1]
120
+ HTML_str = (
121
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
122
+ " </center>"
123
+ )
124
+ return HTML_str
125
+
126
def download_yt_audio(yt_url, filename):
    """Check a YouTube video's length and download it to ``filename``.

    Args:
        yt_url: URL of the video to fetch.
        filename: Output path handed to yt-dlp as its ``outtmpl``.

    Raises:
        HTTPException: 400 when the metadata cannot be fetched, the duration is
            unavailable, the video is longer than ``YT_LENGTH_LIMIT_S`` seconds,
            or the download itself fails.
    """
    info_loader = youtube_dl.YoutubeDL()

    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        raise HTTPException(status_code=400, detail=str(err))

    # The duration can be absent from the metadata; answer 400 instead of
    # failing with a KeyError.
    file_length = info.get("duration_string")
    if not file_length:
        raise HTTPException(status_code=400, detail="Could not determine the length of the YouTube video.")

    # "SS", "MM:SS" or "HH:MM:SS": left-pad with zeros to three fields.
    file_h_m_s = [int(sub_length) for sub_length in file_length.split(":")]
    file_h_m_s = [0] * (3 - len(file_h_m_s)) + file_h_m_s

    file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise HTTPException(status_code=400, detail=f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")

    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            # yt-dlp reports failed downloads as DownloadError; catching only
            # ExtractorError let them escape as unhandled errors.
            raise HTTPException(status_code=400, detail=str(err))