Leandro Artaza commited on
Commit
c97ee64
1 Parent(s): 9b18094

Add main.py

Browse files
Files changed (1) hide show
  1. main.py +83 -0
main.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import numpy as np
4
+ import librosa
5
+ import re
6
+ from unidecode import unidecode
7
+ import base64
8
+
9
+ app = FastAPI()
10
+
11
+ class AudioBytesEncoded(BaseModel):
12
+ audio_bytes_encoded: str
13
+
14
+ def get_model():
15
+ import faster_whisper
16
+ return faster_whisper.WhisperModel('tiny')
17
+
18
+ model_faster = get_model()
19
+
20
+ def CorregirErrores(texto):
21
+ for key, value in {'dinea': 'linea', 'dino': 'linea', 'tos': 'dos', 'dra': 'tra', 'una': 'uno', 'tes': 'tres', '1': 'uno', '2': 'dos', '3': 'tres'}.items():
22
+ texto = texto.replace(key, value)
23
+ return texto
24
+
25
+ comando_base = ['linea', 'tra']
26
+ comando_num = ['uno', 'dos', 'tres']
27
+
28
+ pattern = '|'.join([re.escape(word) for word in comando_base + comando_num])
29
+
30
+ nombre_clases = ['linea_uno', 'linea_dos', 'linea_tres', 'tra_uno', 'tra_dos', 'tra_tres']
31
+ def predecir(audio):
32
+ resultado_final = None
33
+
34
+ completado = False
35
+ params1 = {'initial_prompt': 'Línea 1. Línea 2. Línea 3. Tra 1. Tra 2. Tra 3.',
36
+ 'suppress_tokens': [],
37
+ 'repetition_penalty': 2,
38
+ 'no_speech_threshold': 0.1,
39
+ 'log_prob_threshold': -0.1}
40
+ params2 = {'initial_prompt': [],
41
+ 'suppress_tokens': [],
42
+ 'repetition_penalty': 2,}
43
+ for params in (params1, params2):
44
+ for temp in [0, 1.0]:
45
+ resultado_original = model_faster.transcribe(audio, language='es', temperature=temp, **params)[0]
46
+ try:
47
+ resultado_original = next(resultado_original).text
48
+ except:
49
+ print('Falló la conversion.')
50
+ continue
51
+ print('Predicción:\t', resultado_original, end='\n')
52
+ resultado = unidecode(resultado_original.lower().strip())
53
+
54
+ resultado = CorregirErrores(resultado)
55
+ for resultado in resultado.split('.'):
56
+ matches = re.findall(pattern, resultado)
57
+
58
+ resultado_final = '_'.join(matches)
59
+ if resultado_final in nombre_clases:
60
+ completado = True
61
+ break
62
+ if completado:
63
+ break
64
+ if completado:
65
+ break
66
+
67
+ if resultado_final not in nombre_clases:
68
+ resultado = 'Comando no reconocido.'
69
+ return resultado
70
+
71
+ @app.post("/predict/")
72
+ async def predict(audio_bytes_encoded: AudioBytesEncoded):
73
+ try:
74
+ audio_bytes = base64.b64decode(audio_bytes_encoded.audio_bytes_encoded)
75
+ audio_np = np.frombuffer(audio_bytes, dtype=np.float32)
76
+ audio_np = librosa.util.normalize(audio_np)
77
+
78
+ prediction = predecir(audio_np)
79
+ return {"prediction": prediction}
80
+ except Exception as e:
81
+ print(f"An error occurred: {e}")
82
+ raise HTTPException(status_code=500, detail=str(e))
83
+