Yhhxhfh commited on
Commit
c89a827
1 Parent(s): 4605982

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -35
app.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  import logging
3
  import asyncio
4
  import uvicorn
5
- import torch
6
  import random
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from fastapi import FastAPI, Query, HTTPException
@@ -33,20 +32,23 @@ async def load_models():
33
  ]
34
  gpt_models = ["gpt2-medium", "gpt2-large", "gpt2", "gemma-2-9b", "starcoder"] + programming_models
35
 
 
36
  for model_name in gpt_models:
37
  try:
38
  model = AutoModelForCausalLM.from_pretrained(model_name)
39
  tokenizer = AutoTokenizer.from_pretrained(model_name)
40
  logger.info(f"Successfully loaded {model_name} model")
41
- return model, tokenizer
42
  except Exception as e:
43
  logger.error(f"Failed to load {model_name} model: {e}")
44
- raise HTTPException(status_code=500, detail="Failed to load any models")
 
 
45
 
46
  # Función para descargar modelos
47
  async def download_models():
48
- model, tokenizer = await load_models()
49
- data_and_models_dict['model'] = (model, tokenizer)
50
 
51
  @app.get('/')
52
  async def main():
@@ -153,9 +155,9 @@ async def main():
153
 
154
  saveMessage('user', userMessage);
155
  await fetch(`/autocomplete?q=${userMessage}`)
156
- .then(response => response.text())
157
  .then(data => {
158
- saveMessage('bot', data);
159
  chatBox.scrollTop = chatBox.scrollHeight;
160
  })
161
  .catch(error => console.error('Error:', error));
@@ -179,35 +181,47 @@ async def autocomplete(q: str = Query(...)):
179
  global data_and_models_dict, message_history, tokens_history
180
 
181
  # Verificar si hay modelos cargados
182
- if 'model' not in data_and_models_dict:
183
  await download_models()
184
 
185
- # Cargar el modelo y el tokenizer
186
- model, tokenizer = data_and_models_dict['model']
187
 
188
- # Generar tokens de entrada
189
- input_ids = tokenizer.encode(q, return_tensors="pt")
190
- tokens_history.append({"input": input_ids.tolist()}) # Guardar tokens de entrada
191
-
192
- # Generar parámetros aleatorios
193
- top_k = random.randint(0, 50)
194
- top_p = random.uniform(0.8, 1.0)
195
- temperature = random.uniform(0.7, 1.5)
196
-
197
- # Generar una respuesta utilizando el modelo
198
- output = model.generate(
199
- input_ids,
200
- max_length=50,
201
- top_k=top_k,
202
- top_p=top_p,
203
- temperature=temperature,
204
- num_return_sequences=1
205
- )
206
- response_text = tokenizer.decode(output[0], skip_special_tokens=True)
207
-
208
- # Generar tokens de salida
209
- output_ids = output[0].tolist()
210
- tokens_history.append({"output": output_ids}) # Guardar tokens de salida
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  # Guardar eos y pad tokens
213
  eos_token = tokenizer.eos_token_id
@@ -217,12 +231,13 @@ async def autocomplete(q: str = Query(...)):
217
  # Guardar el mensaje del usuario en el historial
218
  message_history.append(q)
219
 
220
- return response_text
 
221
 
222
  # Función para ejecutar la aplicación sin reiniciarla
223
  def run_app():
224
  asyncio.run(download_models())
225
- uvicorn.run(app, host='0.0.0.0', port=4443)
226
 
227
  # Ejecutar la aplicación
228
  if __name__ == "__main__":
 
2
  import logging
3
  import asyncio
4
  import uvicorn
 
5
  import random
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from fastapi import FastAPI, Query, HTTPException
 
32
  ]
33
  gpt_models = ["gpt2-medium", "gpt2-large", "gpt2", "gemma-2-9b", "starcoder"] + programming_models
34
 
35
+ models = []
36
  for model_name in gpt_models:
37
  try:
38
  model = AutoModelForCausalLM.from_pretrained(model_name)
39
  tokenizer = AutoTokenizer.from_pretrained(model_name)
40
  logger.info(f"Successfully loaded {model_name} model")
41
+ models.append((model, tokenizer, model_name))
42
  except Exception as e:
43
  logger.error(f"Failed to load {model_name} model: {e}")
44
+ if not models:
45
+ raise HTTPException(status_code=500, detail="Failed to load any models")
46
+ return models
47
 
48
  # Función para descargar modelos
49
  async def download_models():
50
+ models = await load_models()
51
+ data_and_models_dict['models'] = models
52
 
53
  @app.get('/')
54
  async def main():
 
155
 
156
  saveMessage('user', userMessage);
157
  await fetch(`/autocomplete?q=${userMessage}`)
158
+ .then(response => response.json())
159
  .then(data => {
160
+ saveMessage('bot', data.response);
161
  chatBox.scrollTop = chatBox.scrollHeight;
162
  })
163
  .catch(error => console.error('Error:', error));
 
181
  global data_and_models_dict, message_history, tokens_history
182
 
183
  # Verificar si hay modelos cargados
184
+ if 'models' not in data_and_models_dict:
185
  await download_models()
186
 
187
+ # Obtener los modelos
188
+ models = data_and_models_dict['models']
189
 
190
+ best_response = None
191
+ best_score = float('-inf') # Para almacenar la mejor puntuación
192
+
193
+ for model, tokenizer, model_name in models:
194
+ # Generar tokens de entrada
195
+ input_ids = tokenizer.encode(q, return_tensors="pt")
196
+ tokens_history.append({"input": input_ids.tolist()}) # Guardar tokens de entrada
197
+
198
+ # Generar parámetros aleatorios
199
+ top_k = random.randint(0, 50)
200
+ top_p = random.uniform(0.8, 1.0)
201
+ temperature = random.uniform(0.7, 1.5)
202
+
203
+ # Generar una respuesta utilizando el modelo
204
+ output = model.generate(
205
+ input_ids,
206
+ max_length=50,
207
+ top_k=top_k,
208
+ top_p=top_p,
209
+ temperature=temperature,
210
+ num_return_sequences=1
211
+ )
212
+ response_text = tokenizer.decode(output[0], skip_special_tokens=True)
213
+
214
+ # Calcular una puntuación simple para determinar la mejor respuesta
215
+ score = len(response_text) # Aquí podrías usar otro criterio de puntuación
216
+
217
+ # Comparar y almacenar la mejor respuesta
218
+ if score > best_score:
219
+ best_score = score
220
+ best_response = response_text
221
+
222
+ # Generar tokens de salida
223
+ output_ids = output[0].tolist()
224
+ tokens_history.append({"output": output_ids}) # Guardar tokens de salida
225
 
226
  # Guardar eos y pad tokens
227
  eos_token = tokenizer.eos_token_id
 
231
  # Guardar el mensaje del usuario en el historial
232
  message_history.append(q)
233
 
234
+ # Respuesta con la mejor respuesta generada
235
+ return {"response": best_response}
236
 
237
  # Función para ejecutar la aplicación sin reiniciarla
238
  def run_app():
239
  asyncio.run(download_models())
240
+ uvicorn.run(app, host='0.0.0.0', port=7860)
241
 
242
  # Ejecutar la aplicación
243
  if __name__ == "__main__":