Afrinetwork7 committed on
Commit
92975f5
1 Parent(s): 23b21ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -137
app.py CHANGED
@@ -17,40 +17,41 @@ from transformers.pipelines.audio_utils import ffmpeg_read
17
 
18
  from whisper_jax import FlaxWhisperPipline
19
 
20
- cc.initialize_cache("./jax_cache")
21
- checkpoint = "openai/whisper-large-v3"
 
22
 
23
- BATCH_SIZE = 32
24
- CHUNK_LENGTH_S = 30
25
- NUM_PROC = 32
26
- FILE_LIMIT_MB = 10000
27
- YT_LENGTH_LIMIT_S = 15000 # limit to 15000 s (about 4 h 10 min) YouTube files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- logger = logging.getLogger("whisper-jax-app")
30
- logger.setLevel(logging.INFO)
31
- ch = logging.StreamHandler()
32
- ch.setLevel(logging.INFO)
33
- formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
34
- ch.setFormatter(formatter)
35
- logger.addHandler(ch)
36
-
37
- pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
38
- stride_length_s = CHUNK_LENGTH_S / 6
39
- chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
40
- stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
41
- step = chunk_len - stride_left - stride_right
42
-
43
- # do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
44
- logger.info("compiling forward call...")
45
- start = time.time()
46
- random_inputs = {
47
- "input_features": np.ones(
48
- (BATCH_SIZE, pipeline.model.config.num_mel_bins, 2 * pipeline.model.config.max_source_positions)
49
- )
50
- }
51
- random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
52
- compile_time = time.time() - start
53
- logger.info(f"compiled in {compile_time}s")
54
 
55
  app = fastapi.FastAPI()
56
 
@@ -65,128 +66,150 @@ class TranscriptionResponse(BaseModel):
65
 
66
  @app.post("/transcribe", response_model=TranscriptionResponse)
67
  def transcribe_audio(request: TranscriptionRequest):
68
- logger.info("loading audio file...")
69
- if not request.audio_file:
70
- logger.warning("No audio file")
71
- raise fastapi.HTTPException(status_code=400, detail="No audio file submitted!")
72
-
73
- audio_bytes = base64.b64decode(request.audio_file)
74
- file_size_mb = len(audio_bytes) / (1024 * 1024)
75
- if file_size_mb > FILE_LIMIT_MB:
76
- logger.warning("Max file size exceeded")
77
- raise fastapi.HTTPException(
78
- status_code=400,
79
- detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.",
80
- )
81
-
82
- inputs = ffmpeg_read(audio_bytes, pipeline.feature_extractor.sampling_rate)
83
- inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
84
- logger.info("done loading")
85
- text, runtime = _tqdm_generate(inputs, task=request.task, return_timestamps=request.return_timestamps)
86
- return TranscriptionResponse(transcription=text, runtime=runtime)
 
 
 
 
87
 
88
  @app.post("/transcribe_youtube")
89
  def transcribe_youtube(
90
  yt_url: str, task: str = "transcribe", return_timestamps: bool = False
91
  ) -> Tuple[str, str, float]:
92
- logger.info("loading youtube file...")
93
- html_embed_str = _return_yt_html_embed(yt_url)
94
- with tempfile.TemporaryDirectory() as tmpdirname:
95
- filepath = os.path.join(tmpdirname, "video.mp4")
96
- _download_yt_audio(yt_url, filepath)
97
-
98
- with open(filepath, "rb") as f:
99
- inputs = f.read()
100
-
101
- inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
102
- inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
103
- logger.info("done loading...")
104
- text, runtime = _tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
105
- return html_embed_str, text, runtime
 
 
 
 
106
 
107
  def _tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
108
- inputs_len = inputs["array"].shape[0]
109
- all_chunk_start_idx = np.arange(0, inputs_len, step)
110
- num_samples = len(all_chunk_start_idx)
111
- num_batches = math.ceil(num_samples / BATCH_SIZE)
112
-
113
- dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
114
- model_outputs = []
115
- start_time = time.time()
116
- logger.info("transcribing...")
117
- # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations
118
- for batch, _ in zip(dataloader, range(num_batches)):
119
- model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
120
- runtime = time.time() - start_time
121
- logger.info("done transcription")
122
-
123
- logger.info("post-processing...")
124
- post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
125
- text = post_processed["text"]
126
- if return_timestamps:
127
- timestamps = post_processed.get("chunks")
128
- timestamps = [
129
- f"[{_format_timestamp(chunk['timestamp'][0])} -> {_format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
130
- for chunk in timestamps
131
- ]
132
- text = "\n".join(str(feature) for feature in timestamps)
133
- logger.info("done post-processing")
134
- return text, runtime
 
 
 
135
 
136
  def _return_yt_html_embed(yt_url: str) -> str:
137
- video_id = yt_url.split("?v=")[-1]
138
- HTML_str = (
139
- f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
140
- " </center>"
141
- )
142
- return HTML_str
143
-
144
- def _download_yt_audio(yt_url: str, filename: str):
145
- info_loader = youtube_dl.YoutubeDL()
146
  try:
147
- info = info_loader.extract_info(yt_url, download=False)
148
- except youtube_dl.utils.DownloadError as err:
149
- raise fastapi.HTTPException(status_code=400, detail=str(err))
150
-
151
- file_length = info["duration_string"]
152
- file_h_m_s = file_length.split(":")
153
- file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
154
- if len(file_h_m_s) == 1:
155
- file_h_m_s.insert(0, 0)
156
- if len(file_h_m_s) == 2:
157
- file_h_m_s.insert(0, 0)
158
-
159
- file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
160
- if file_length_s > YT_LENGTH_LIMIT_S:
161
- yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
162
- file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
163
- raise fastapi.HTTPException(
164
- status_code=400,
165
- detail=f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.",
166
  )
 
 
 
 
167
 
168
- ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
169
- with youtube_dl.YoutubeDL(ydl_opts) as ydl:
 
170
  try:
171
- ydl.download([yt_url])
172
- except youtube_dl.utils.ExtractorError as err:
173
  raise fastapi.HTTPException(status_code=400, detail=str(err))
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  def _format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
176
- if seconds is not None:
177
- milliseconds = round(seconds * 1000.0)
 
178
 
179
- hours = milliseconds // 3_600_000
180
- milliseconds -= hours * 3_600_000
181
 
182
- minutes = milliseconds // 60_000
183
- milliseconds -= minutes * 60_000
184
 
185
- seconds = milliseconds // 1_000
186
- milliseconds -= seconds * 1_000
187
 
188
- hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
189
- return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
190
- else:
191
- # we have a malformed timestamp so just return it as is
192
- return seconds
 
 
 
17
 
18
  from whisper_jax import FlaxWhisperPipline
19
 
20
+ # Set up logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger("whisper-jax-app")
23
 
24
+ try:
25
+ cc.initialize_cache("./jax_cache")
26
+ checkpoint = "openai/whisper-large-v3"
27
+
28
+ BATCH_SIZE = 32
29
+ CHUNK_LENGTH_S = 30
30
+ NUM_PROC = 32
31
+ FILE_LIMIT_MB = 10000
32
+ YT_LENGTH_LIMIT_S = 15000 # limit to 15000 s (about 4 h 10 min) YouTube files
33
+
34
+ pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
35
+ stride_length_s = CHUNK_LENGTH_S / 6
36
+ chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
37
+ stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
38
+ step = chunk_len - stride_left - stride_right
39
+
40
+ # do a pre-compile step
41
+ logger.info("compiling forward call...")
42
+ start = time.time()
43
+ random_inputs = {
44
+ "input_features": np.ones(
45
+ (BATCH_SIZE, pipeline.model.config.num_mel_bins, 2 * pipeline.model.config.max_source_positions)
46
+ )
47
+ }
48
+ random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
49
+ compile_time = time.time() - start
50
+ logger.info(f"compiled in {compile_time}s")
51
 
52
+ except Exception as e:
53
+ logger.error(f"Error during initialization: {str(e)}")
54
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  app = fastapi.FastAPI()
57
 
 
66
 
67
  @app.post("/transcribe", response_model=TranscriptionResponse)
68
  def transcribe_audio(request: TranscriptionRequest):
69
+ try:
70
+ logger.info("loading audio file...")
71
+ if not request.audio_file:
72
+ logger.warning("No audio file")
73
+ raise fastapi.HTTPException(status_code=400, detail="No audio file submitted!")
74
+
75
+ audio_bytes = base64.b64decode(request.audio_file)
76
+ file_size_mb = len(audio_bytes) / (1024 * 1024)
77
+ if file_size_mb > FILE_LIMIT_MB:
78
+ logger.warning("Max file size exceeded")
79
+ raise fastapi.HTTPException(
80
+ status_code=400,
81
+ detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.",
82
+ )
83
+
84
+ inputs = ffmpeg_read(audio_bytes, pipeline.feature_extractor.sampling_rate)
85
+ inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
86
+ logger.info("done loading")
87
+ text, runtime = _tqdm_generate(inputs, task=request.task, return_timestamps=request.return_timestamps)
88
+ return TranscriptionResponse(transcription=text, runtime=runtime)
89
+ except Exception as e:
90
+ logger.error(f"Error in transcribe_audio: {str(e)}")
91
+ raise fastapi.HTTPException(status_code=500, detail=f"An error occurred during transcription: {str(e)}")
92
 
93
  @app.post("/transcribe_youtube")
94
  def transcribe_youtube(
95
  yt_url: str, task: str = "transcribe", return_timestamps: bool = False
96
  ) -> Tuple[str, str, float]:
97
+ try:
98
+ logger.info("loading youtube file...")
99
+ html_embed_str = _return_yt_html_embed(yt_url)
100
+ with tempfile.TemporaryDirectory() as tmpdirname:
101
+ filepath = os.path.join(tmpdirname, "video.mp4")
102
+ _download_yt_audio(yt_url, filepath)
103
+
104
+ with open(filepath, "rb") as f:
105
+ inputs = f.read()
106
+
107
+ inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
108
+ inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
109
+ logger.info("done loading...")
110
+ text, runtime = _tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
111
+ return html_embed_str, text, runtime
112
+ except Exception as e:
113
+ logger.error(f"Error in transcribe_youtube: {str(e)}")
114
+ raise fastapi.HTTPException(status_code=500, detail=f"An error occurred during YouTube transcription: {str(e)}")
115
 
116
  def _tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
117
+ try:
118
+ inputs_len = inputs["array"].shape[0]
119
+ all_chunk_start_idx = np.arange(0, inputs_len, step)
120
+ num_samples = len(all_chunk_start_idx)
121
+ num_batches = math.ceil(num_samples / BATCH_SIZE)
122
+
123
+ dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
124
+ model_outputs = []
125
+ start_time = time.time()
126
+ logger.info("transcribing...")
127
+ for batch, _ in zip(dataloader, range(num_batches)):
128
+ model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
129
+ runtime = time.time() - start_time
130
+ logger.info("done transcription")
131
+
132
+ logger.info("post-processing...")
133
+ post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
134
+ text = post_processed["text"]
135
+ if return_timestamps:
136
+ timestamps = post_processed.get("chunks")
137
+ timestamps = [
138
+ f"[{_format_timestamp(chunk['timestamp'][0])} -> {_format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
139
+ for chunk in timestamps
140
+ ]
141
+ text = "\n".join(str(feature) for feature in timestamps)
142
+ logger.info("done post-processing")
143
+ return text, runtime
144
+ except Exception as e:
145
+ logger.error(f"Error in _tqdm_generate: {str(e)}")
146
+ raise
147
 
148
  def _return_yt_html_embed(yt_url: str) -> str:
 
 
 
 
 
 
 
 
 
149
  try:
150
+ video_id = yt_url.split("?v=")[-1]
151
+ HTML_str = (
152
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
153
+ " </center>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  )
155
+ return HTML_str
156
+ except Exception as e:
157
+ logger.error(f"Error in _return_yt_html_embed: {str(e)}")
158
+ raise
159
 
160
+ def _download_yt_audio(yt_url: str, filename: str):
161
+ try:
162
+ info_loader = youtube_dl.YoutubeDL()
163
  try:
164
+ info = info_loader.extract_info(yt_url, download=False)
165
+ except youtube_dl.utils.DownloadError as err:
166
  raise fastapi.HTTPException(status_code=400, detail=str(err))
167
 
168
+ file_length = info["duration_string"]
169
+ file_h_m_s = file_length.split(":")
170
+ file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
171
+ if len(file_h_m_s) == 1:
172
+ file_h_m_s.insert(0, 0)
173
+ if len(file_h_m_s) == 2:
174
+ file_h_m_s.insert(0, 0)
175
+
176
+ file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
177
+ if file_length_s > YT_LENGTH_LIMIT_S:
178
+ yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
179
+ file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
180
+ raise fastapi.HTTPException(
181
+ status_code=400,
182
+ detail=f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.",
183
+ )
184
+
185
+ ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
186
+ with youtube_dl.YoutubeDL(ydl_opts) as ydl:
187
+ try:
188
+ ydl.download([yt_url])
189
+ except youtube_dl.utils.ExtractorError as err:
190
+ raise fastapi.HTTPException(status_code=400, detail=str(err))
191
+ except Exception as e:
192
+ logger.error(f"Error in _download_yt_audio: {str(e)}")
193
+ raise
194
+
195
  def _format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
196
+ try:
197
+ if seconds is not None:
198
+ milliseconds = round(seconds * 1000.0)
199
 
200
+ hours = milliseconds // 3_600_000
201
+ milliseconds -= hours * 3_600_000
202
 
203
+ minutes = milliseconds // 60_000
204
+ milliseconds -= minutes * 60_000
205
 
206
+ seconds = milliseconds // 1_000
207
+ milliseconds -= seconds * 1_000
208
 
209
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
210
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
211
+ else:
212
+ return seconds
213
+ except Exception as e:
214
+ logger.error(f"Error in _format_timestamp: {str(e)}")
215
+ raise