Update app.py
app.py CHANGED
@@ -2,7 +2,6 @@ import os
 import logging
 import asyncio
 import uvicorn
-import torch
 import random
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from fastapi import FastAPI, Query, HTTPException
@@ -33,20 +32,23 @@ async def load_models():
     ]
     gpt_models = ["gpt2-medium", "gpt2-large", "gpt2", "gemma-2-9b", "starcoder"] + programming_models

+    models = []
     for model_name in gpt_models:
         try:
             model = AutoModelForCausalLM.from_pretrained(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name)
             logger.info(f"Successfully loaded {model_name} model")
-
+            models.append((model, tokenizer, model_name))
         except Exception as e:
             logger.error(f"Failed to load {model_name} model: {e}")
-
+    if not models:
+        raise HTTPException(status_code=500, detail="Failed to load any models")
+    return models

 # Function to download the models
 async def download_models():
-
-    data_and_models_dict['
+    models = await load_models()
+    data_and_models_dict['models'] = models

 @app.get('/')
 async def main():
@@ -153,9 +155,9 @@ async def main():

             saveMessage('user', userMessage);
             await fetch(`/autocomplete?q=${userMessage}`)
-                .then(response => response.
+                .then(response => response.json())
                 .then(data => {
-                    saveMessage('bot', data);
+                    saveMessage('bot', data.response);
                     chatBox.scrollTop = chatBox.scrollHeight;
                 })
                 .catch(error => console.error('Error:', error));
@@ -179,35 +181,47 @@ async def autocomplete(q: str = Query(...)):
     global data_and_models_dict, message_history, tokens_history

     # Check whether any models are loaded
-    if '
+    if 'models' not in data_and_models_dict:
         await download_models()

-    #
-
+    # Get the models
+    models = data_and_models_dict['models']

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    best_response = None
+    best_score = float('-inf')  # To store the best score
+
+    for model, tokenizer, model_name in models:
+        # Generate input tokens
+        input_ids = tokenizer.encode(q, return_tensors="pt")
+        tokens_history.append({"input": input_ids.tolist()})  # Save input tokens
+
+        # Generate random parameters
+        top_k = random.randint(0, 50)
+        top_p = random.uniform(0.8, 1.0)
+        temperature = random.uniform(0.7, 1.5)
+
+        # Generate a response using the model
+        output = model.generate(
+            input_ids,
+            max_length=50,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
+            num_return_sequences=1
+        )
+        response_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+        # Compute a simple score to determine the best response
+        score = len(response_text)  # Another scoring criterion could be used here
+
+        # Compare and keep the best response
+        if score > best_score:
+            best_score = score
+            best_response = response_text
+
+        # Generate output tokens
+        output_ids = output[0].tolist()
+        tokens_history.append({"output": output_ids})  # Save output tokens

     # Save eos and pad tokens
     eos_token = tokenizer.eos_token_id
@@ -217,12 +231,13 @@ async def autocomplete(q: str = Query(...)):
     # Save the user's message in the history
     message_history.append(q)

-
+    # Respond with the best generated response
+    return {"response": best_response}

 # Function to run the application without restarting it
 def run_app():
     asyncio.run(download_models())
-    uvicorn.run(app, host='0.0.0.0', port=
+    uvicorn.run(app, host='0.0.0.0', port=7860)

 # Run the application
 if __name__ == "__main__":
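
One caveat on the generation loop added above: `model.generate` only applies `top_k`, `top_p`, and `temperature` when sampling is enabled, and these models default to greedy decoding (`do_sample=False`), so the randomized parameters would not actually change the output. A minimal sketch of the same call with sampling switched on, reusing the names from the diff above:

# Sketch: enable sampling so the randomized top_k / top_p / temperature
# are actually used; with the default greedy decoding they are ignored.
output = model.generate(
    input_ids,
    max_length=50,
    do_sample=True,
    top_k=top_k,
    top_p=top_p,
    temperature=temperature,
    num_return_sequences=1
)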
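
For a quick manual check of the new response shape the frontend now expects, the `/autocomplete` endpoint can be called directly once the app is running. A small sketch against a local instance; the port comes from `run_app` above, while the `requests` dependency and the example prompt are assumptions for illustration only:

import requests

# Hypothetical local check: the updated frontend reads data.response,
# so the endpoint is expected to return {"response": "..."}.
resp = requests.get(
    "http://localhost:7860/autocomplete",
    params={"q": "def fibonacci(n):"},
)
resp.raise_for_status()
print(resp.json()["response"])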
|