Spaces:
Edmond98
/
Running on A100

Edmond7 committed on
Commit
c0c3100
1 Parent(s): 4b305c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -9
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
3
  from fastapi.security.api_key import APIKeyHeader, APIKey
4
  from fastapi.responses import JSONResponse
5
  from pydantic import BaseModel
@@ -70,9 +70,6 @@ class TTSRequest(BaseModel):
70
  class LanguageRequest(BaseModel):
71
  language: Optional[str] = None
72
 
73
- class TranscribeFileRequest(BaseModel):
74
- language: Optional[str] = None
75
-
76
  async def get_api_key(api_key_header: str = Security(api_key_header)):
77
  if api_key_header == API_KEY:
78
  return api_key_header
@@ -165,7 +162,11 @@ async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_
165
  )
166
 
167
  @app.post("/transcribe_file")
168
- async def transcribe_audio_file(file: UploadFile = File(...), request: TranscribeFileRequest = Depends(), api_key: APIKey = Depends(get_api_key)):
 
 
 
 
169
  start_time = time.time()
170
  try:
171
  contents = await file.read()
@@ -178,12 +179,12 @@ async def transcribe_audio_file(file: UploadFile = File(...), request: Transcrib
178
  if sample_rate != ASR_SAMPLING_RATE:
179
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
180
 
181
- if request.language is None:
182
  # If no language is provided, use language identification
183
  identified_language = identify(audio_array)
184
  result = transcribe(audio_array, identified_language)
185
  else:
186
- result = transcribe(audio_array, request.language)
187
 
188
  processing_time = time.time() - start_time
189
  return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
@@ -321,7 +322,10 @@ async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get
321
  )
322
 
323
  @app.post("/identify_file")
324
- async def identify_language_file(file: UploadFile = File(...), api_key: APIKey = Depends(get_api_key)):
 
 
 
325
  start_time = time.time()
326
  try:
327
  contents = await file.read()
@@ -339,7 +343,7 @@ async def identify_language_file(file: UploadFile = File(...), api_key: APIKey =
339
  return JSONResponse(
340
  status_code=500,
341
  content={"message": "An error occurred during language identification", "details": error_details, "processing_time_seconds": processing_time}
342
- )
343
 
344
  @app.post("/asr_languages")
345
  async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
@@ -388,3 +392,26 @@ async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(
388
  status_code=500,
389
  content={"message": "An error occurred while fetching TTS languages", "details": error_details, "processing_time_seconds": processing_time}
390
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security, Form
3
  from fastapi.security.api_key import APIKeyHeader, APIKey
4
  from fastapi.responses import JSONResponse
5
  from pydantic import BaseModel
 
70
  class LanguageRequest(BaseModel):
71
  language: Optional[str] = None
72
 
 
 
 
73
  async def get_api_key(api_key_header: str = Security(api_key_header)):
74
  if api_key_header == API_KEY:
75
  return api_key_header
 
162
  )
163
 
164
  @app.post("/transcribe_file")
165
+ async def transcribe_audio_file(
166
+ file: UploadFile = File(...),
167
+ language: Optional[str] = Form(None),
168
+ api_key: APIKey = Depends(get_api_key)
169
+ ):
170
  start_time = time.time()
171
  try:
172
  contents = await file.read()
 
179
  if sample_rate != ASR_SAMPLING_RATE:
180
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
181
 
182
+ if language is None:
183
  # If no language is provided, use language identification
184
  identified_language = identify(audio_array)
185
  result = transcribe(audio_array, identified_language)
186
  else:
187
+ result = transcribe(audio_array, language)
188
 
189
  processing_time = time.time() - start_time
190
  return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
 
322
  )
323
 
324
  @app.post("/identify_file")
325
+ async def identify_language_file(
326
+ file: UploadFile = File(...),
327
+ api_key: APIKey = Depends(get_api_key)
328
+ ):
329
  start_time = time.time()
330
  try:
331
  contents = await file.read()
 
343
  return JSONResponse(
344
  status_code=500,
345
  content={"message": "An error occurred during language identification", "details": error_details, "processing_time_seconds": processing_time}
346
+ # ... (previous code remains the same)
347
 
348
  @app.post("/asr_languages")
349
  async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
 
392
  status_code=500,
393
  content={"message": "An error occurred while fetching TTS languages", "details": error_details, "processing_time_seconds": processing_time}
394
  )
395
+
396
+ # If you want to add a health check endpoint
397
+ @app.get("/health")
398
+ async def health_check():
399
+ return {"status": "ok"}
400
+
401
+ # You might also want to add a root endpoint that provides basic API information
402
+ @app.get("/")
403
+ async def root():
404
+ return {
405
+ "message": "Welcome to the MMS Speech Technology API",
406
+ "version": "1.0",
407
+ "endpoints": [
408
+ "/transcribe",
409
+ "/transcribe_file",
410
+ "/synthesize",
411
+ "/identify",
412
+ "/identify_file",
413
+ "/asr_languages",
414
+ "/tts_languages",
415
+ "/health"
416
+ ]
417
+ }