Afrinetwork7
committed on
Commit
•
2ecbad4
1
Parent(s):
8da2b37
Update app.py
Browse files
app.py
CHANGED
@@ -12,15 +12,6 @@ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
|
|
12 |
from transformers.pipelines.audio_utils import ffmpeg_read
|
13 |
from whisper_jax import FlaxWhisperPipline
|
14 |
|
15 |
-
cc.initialize_cache("./jax_cache")
|
16 |
-
checkpoint = "openai/whisper-large-v3"
|
17 |
-
|
18 |
-
BATCH_SIZE = 32
|
19 |
-
CHUNK_LENGTH_S = 30
|
20 |
-
NUM_PROC = 32
|
21 |
-
FILE_LIMIT_MB = 10000
|
22 |
-
YT_LENGTH_LIMIT_S = 15000 # 15000 s is about 4 h 10 min, not 2 hours as previously noted -- TODO confirm intended YouTube length limit
|
23 |
-
|
24 |
app = FastAPI(title="Whisper JAX: The Fastest Whisper API ⚡️")
|
25 |
|
26 |
logger = logging.getLogger("whisper-jax-app")
|
@@ -31,6 +22,14 @@ formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d
|
|
31 |
ch.setFormatter(formatter)
|
32 |
logger.addHandler(ch)
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
|
35 |
stride_length_s = CHUNK_LENGTH_S / 6
|
36 |
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
|
@@ -149,4 +148,23 @@ def download_yt_audio(yt_url, filename):
|
|
149 |
try:
|
150 |
ydl.download([yt_url])
|
151 |
except youtube_dl.utils.ExtractorError as err:
|
152 |
-
raise HTTPException(status_code=400, detail=str(err))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
from transformers.pipelines.audio_utils import ffmpeg_read
|
13 |
from whisper_jax import FlaxWhisperPipline
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
app = FastAPI(title="Whisper JAX: The Fastest Whisper API ⚡️")
|
16 |
|
17 |
logger = logging.getLogger("whisper-jax-app")
|
|
|
22 |
ch.setFormatter(formatter)
|
23 |
logger.addHandler(ch)
|
24 |
|
25 |
+
checkpoint = "openai/whisper-large-v3"
|
26 |
+
|
27 |
+
BATCH_SIZE = 32
|
28 |
+
CHUNK_LENGTH_S = 30
|
29 |
+
NUM_PROC = 32
|
30 |
+
FILE_LIMIT_MB = 10000
|
31 |
+
YT_LENGTH_LIMIT_S = 15000 # 15000 s is about 4 h 10 min, not 2 hours as previously noted -- TODO confirm intended YouTube length limit
|
32 |
+
|
33 |
pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
|
34 |
stride_length_s = CHUNK_LENGTH_S / 6
|
35 |
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
|
|
|
148 |
try:
|
149 |
ydl.download([yt_url])
|
150 |
except youtube_dl.utils.ExtractorError as err:
|
151 |
+
raise HTTPException(status_code=400, detail=str(err))
|
152 |
+
|
153 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """Render a time offset in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: Offset in seconds, or ``None`` for a missing timestamp.
        always_include_hours: Emit the ``HH:`` field even when it is zero.
        decimal_marker: Separator placed between seconds and milliseconds.

    Returns:
        The formatted string (prefixed with ``-`` for negative offsets), or
        ``None`` unchanged when ``seconds`` is ``None``.
    """
    if seconds is None:
        # we have a malformed timestamp so just return it as is
        return seconds

    total_ms = round(seconds * 1000.0)

    # Floor division on a negative total borrows from the larger fields and
    # yields a bogus positive time (e.g. -1.5 -> "59:58.500"), so split the
    # magnitude and carry the sign separately.
    sign = "-" if total_ms < 0 else ""
    total_ms = abs(total_ms)

    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, milliseconds = divmod(remainder, 1_000)

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{sign}{hours_marker}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{milliseconds:03d}"
|