Afrinetwork7 commited on
Commit
d63d47a
1 Parent(s): b3a33de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -21
app.py CHANGED
@@ -15,9 +15,9 @@ from whisper_jax import FlaxWhisperPipline
15
  app = FastAPI(title="Whisper JAX: The Fastest Whisper API ⚡️")
16
 
17
  logger = logging.getLogger("whisper-jax-app")
18
- logger.setLevel(logging.INFO)
19
  ch = logging.StreamHandler()
20
- ch.setLevel(logging.INFO)
21
  formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
22
  ch.setFormatter(formatter)
23
  logger.addHandler(ch)
@@ -37,7 +37,7 @@ stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.
37
  step = chunk_len - stride_left - stride_right
38
 
39
  # do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
40
- logger.info("compiling forward call...")
41
  start = time.time()
42
  random_inputs = {
43
  "input_features": np.ones(
@@ -46,11 +46,11 @@ random_inputs = {
46
  }
47
  random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
48
  compile_time = time.time() - start
49
- logger.info(f"compiled in {compile_time}s")
50
 
51
  @app.post("/transcribe_audio")
52
  async def transcribe_chunked_audio(audio_file: UploadFile, task: str = "transcribe", return_timestamps: bool = False):
53
- logger.info("loading audio file...")
54
  if not audio_file:
55
  logger.warning("No audio file")
56
  raise HTTPException(status_code=400, detail="No audio file submitted!")
@@ -59,30 +59,60 @@ async def transcribe_chunked_audio(audio_file: UploadFile, task: str = "transcri
59
  logger.warning("Max file size exceeded")
60
  raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")
61
 
62
- with open(audio_file.filename, "rb") as f:
63
- inputs = f.read()
 
 
 
 
64
 
65
  inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
66
  inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
67
- logger.info("done loading")
68
- text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
 
 
 
 
 
 
69
  return {"text": text, "runtime": runtime}
70
 
71
  @app.post("/transcribe_youtube")
72
  async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe", return_timestamps: bool = False):
73
- logger.info("loading youtube file...")
74
- html_embed_str = _return_yt_html_embed(yt_url)
 
 
 
 
 
75
  with tempfile.TemporaryDirectory() as tmpdirname:
76
  filepath = os.path.join(tmpdirname, "video.mp4")
77
- download_yt_audio(yt_url, filepath)
 
 
 
 
 
78
 
79
- with open(filepath, "rb") as f:
80
- inputs = f.read()
 
 
 
 
81
 
82
  inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
83
  inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
84
- logger.info("done loading...")
85
- text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
 
 
 
 
 
 
86
  return {"html_embed": html_embed_str, "text": text, "runtime": runtime}
87
 
88
  def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
@@ -94,15 +124,19 @@ def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
94
  dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
95
  model_outputs = []
96
  start_time = time.time()
97
- logger.info("transcribing...")
98
  # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations
99
  for batch in dataloader:
100
  model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
101
  runtime = time.time() - start_time
102
- logger.info("done transcription")
103
 
104
- logger.info("post-processing...")
105
- post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
 
 
 
 
106
  text = post_processed["text"]
107
  if return_timestamps:
108
  timestamps = post_processed.get("chunks")
@@ -111,7 +145,7 @@ def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
111
  for chunk in timestamps
112
  ]
113
  text = "\n".join(str(feature) for feature in timestamps)
114
- logger.info("done post-processing")
115
  return text, runtime
116
 
117
  def _return_yt_html_embed(yt_url):
@@ -125,8 +159,10 @@ def _return_yt_html_embed(yt_url):
125
  def download_yt_audio(yt_url, filename):
126
  info_loader = youtube_dl.YoutubeDL()
127
  try:
 
128
  info = info_loader.extract_info(yt_url, download=False)
129
  except youtube_dl.utils.DownloadError as err:
 
130
  raise HTTPException(status_code=400, detail=str(err))
131
 
132
  file_length = info["duration_string"]
@@ -146,8 +182,10 @@ def download_yt_audio(yt_url, filename):
146
  ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
147
  with youtube_dl.YoutubeDL(ydl_opts) as ydl:
148
  try:
 
149
  ydl.download([yt_url])
150
  except youtube_dl.utils.ExtractorError as err:
 
151
  raise HTTPException(status_code=400, detail=str(err))
152
 
153
  def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
 
15
  app = FastAPI(title="Whisper JAX: The Fastest Whisper API ⚡️")
16
 
17
  logger = logging.getLogger("whisper-jax-app")
18
+ logger.setLevel(logging.DEBUG)
19
  ch = logging.StreamHandler()
20
+ ch.setLevel(logging.DEBUG)
21
  formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
22
  ch.setFormatter(formatter)
23
  logger.addHandler(ch)
 
37
  step = chunk_len - stride_left - stride_right
38
 
39
  # do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
40
+ logger.debug("Compiling forward call...")
41
  start = time.time()
42
  random_inputs = {
43
  "input_features": np.ones(
 
46
  }
47
  random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
48
  compile_time = time.time() - start
49
+ logger.debug(f"Compiled in {compile_time}s")
50
 
51
  @app.post("/transcribe_audio")
52
  async def transcribe_chunked_audio(audio_file: UploadFile, task: str = "transcribe", return_timestamps: bool = False):
53
+ logger.debug("Loading audio file...")
54
  if not audio_file:
55
  logger.warning("No audio file")
56
  raise HTTPException(status_code=400, detail="No audio file submitted!")
 
59
  logger.warning("Max file size exceeded")
60
  raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")
61
 
62
+ try:
63
+ with open(audio_file.filename, "rb") as f:
64
+ inputs = f.read()
65
+ except Exception as e:
66
+ logger.error("Error reading audio file:", exc_info=True)
67
+ raise HTTPException(status_code=500, detail="Error reading audio file")
68
 
69
  inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
70
  inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
71
+ logger.debug("Done loading audio file")
72
+
73
+ try:
74
+ text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
75
+ except Exception as e:
76
+ logger.error("Error transcribing audio:", exc_info=True)
77
+ raise HTTPException(status_code=500, detail="Error transcribing audio")
78
+
79
  return {"text": text, "runtime": runtime}
80
 
81
  @app.post("/transcribe_youtube")
82
  async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe", return_timestamps: bool = False):
83
+ logger.debug("Loading YouTube file...")
84
+ try:
85
+ html_embed_str = _return_yt_html_embed(yt_url)
86
+ except Exception as e:
87
+ logger.error("Error generating YouTube HTML embed:", exc_info=True)
88
+ raise HTTPException(status_code=500, detail="Error generating YouTube HTML embed")
89
+
90
  with tempfile.TemporaryDirectory() as tmpdirname:
91
  filepath = os.path.join(tmpdirname, "video.mp4")
92
+ try:
93
+ logger.debug("Downloading YouTube audio...")
94
+ download_yt_audio(yt_url, filepath)
95
+ except Exception as e:
96
+ logger.error("Error downloading YouTube audio:", exc_info=True)
97
+ raise HTTPException(status_code=500, detail="Error downloading YouTube audio")
98
 
99
+ try:
100
+ with open(filepath, "rb") as f:
101
+ inputs = f.read()
102
+ except Exception as e:
103
+ logger.error("Error reading downloaded audio file:", exc_info=True)
104
+ raise HTTPException(status_code=500, detail="Error reading downloaded audio file")
105
 
106
  inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
107
  inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
108
+ logger.debug("Done loading YouTube file")
109
+
110
+ try:
111
+ text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
112
+ except Exception as e:
113
+ logger.error("Error transcribing YouTube audio:", exc_info=True)
114
+ raise HTTPException(status_code=500, detail="Error transcribing YouTube audio")
115
+
116
  return {"html_embed": html_embed_str, "text": text, "runtime": runtime}
117
 
118
  def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
 
124
  dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
125
  model_outputs = []
126
  start_time = time.time()
127
+ logger.debug("Transcribing...")
128
  # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations
129
  for batch in dataloader:
130
  model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
131
  runtime = time.time() - start_time
132
+ logger.debug("Done transcription")
133
 
134
+ logger.debug("Post-processing...")
135
+ try:
136
+ post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
137
+ except Exception as e:
138
+ logger.error("Error post-processing transcription:", exc_info=True)
139
+ raise HTTPException(status_code=500, detail="Error post-processing transcription")
140
  text = post_processed["text"]
141
  if return_timestamps:
142
  timestamps = post_processed.get("chunks")
 
145
  for chunk in timestamps
146
  ]
147
  text = "\n".join(str(feature) for feature in timestamps)
148
+ logger.debug("Done post-processing")
149
  return text, runtime
150
 
151
  def _return_yt_html_embed(yt_url):
 
159
  def download_yt_audio(yt_url, filename):
160
  info_loader = youtube_dl.YoutubeDL()
161
  try:
162
+ logger.debug(f"Extracting info for YouTube URL: {yt_url}")
163
  info = info_loader.extract_info(yt_url, download=False)
164
  except youtube_dl.utils.DownloadError as err:
165
+ logger.error("Error extracting YouTube info:", exc_info=True)
166
  raise HTTPException(status_code=400, detail=str(err))
167
 
168
  file_length = info["duration_string"]
 
182
  ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
183
  with youtube_dl.YoutubeDL(ydl_opts) as ydl:
184
  try:
185
+ logger.debug(f"Downloading YouTube audio to {filename}")
186
  ydl.download([yt_url])
187
  except youtube_dl.utils.ExtractorError as err:
188
+ logger.error("Error downloading YouTube audio:", exc_info=True)
189
  raise HTTPException(status_code=400, detail=str(err))
190
 
191
  def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):