import asyncio
import logging
import random

import uvicorn
from fastapi import FastAPI, Query, HTTPException
from fastapi.responses import HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
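
# Minimal FastAPI chat demo: serves a single-page chat UI at "/" and an
# /autocomplete endpoint that runs the prompt through several causal LMs
# and returns the highest-scoring (here: longest) generation.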
# Logging configuration
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Initialize the FastAPI application
app = FastAPI()

# Dictionary caching the loaded (model, tokenizer, name) tuples
data_and_models_dict = {}

# History of user messages
message_history = []

# History of input/output token ids
tokens_history = []

# Load every model in the list, skipping any that fail
async def load_models():
    programming_models = [
        "microsoft/CodeGPT-small-py",
        "Salesforce/codegen-350M-multi",
        "Salesforce/codegen-2B-multi"
    ]
    # Note: google/gemma-2-9b is a gated model on the Hugging Face Hub and
    # needs an access token plus substantial memory; if it cannot be loaded
    # it is simply skipped by the error handling below.
    gpt_models = ["gpt2-medium", "gpt2-large", "gpt2", "google/gemma-2-9b"] + programming_models
models = []
for model_name in gpt_models:
try:
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Successfully loaded {model_name} model")
models.append((model, tokenizer, model_name))
except Exception as e:
logger.error(f"Failed to load {model_name} model: {e}")
if not models:
raise HTTPException(status_code=500, detail="Failed to load any models")
return models
# Download all models once and cache them in the module-level dict
async def download_models():
models = await load_models()
data_and_models_dict['models'] = models
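
# A sketch of an alternative, using FastAPI's startup hook so the models
# are preloaded before the first request (assumes the on_event API):
# @app.on_event("startup")
# async def preload_models():
#     await download_models()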
@app.get('/')
async def main():
html_code = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ChatGPT Chatbot</title>
<style>
body, html {
height: 100%;
margin: 0;
padding: 0;
font-family: Arial, sans-serif;
}
.container {
height: 100%;
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
}
.chat-container {
border-radius: 10px;
overflow: hidden;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
width: 100%;
height: 100%;
}
.chat-box {
height: calc(100% - 60px);
overflow-y: auto;
padding: 10px;
}
.chat-input {
width: calc(100% - 100px);
padding: 10px;
border: none;
border-top: 1px solid #ccc;
font-size: 16px;
}
.input-container {
display: flex;
align-items: center;
justify-content: space-between;
padding: 10px;
background-color: #f5f5f5;
border-top: 1px solid #ccc;
width: 100%;
}
button {
padding: 10px;
border: none;
cursor: pointer;
background-color: #007bff;
color: #fff;
font-size: 16px;
}
.user-message {
background-color: #cce5ff;
border-radius: 5px;
align-self: flex-end;
max-width: 70%;
margin-left: auto;
margin-right: 10px;
margin-bottom: 10px;
}
.bot-message {
background-color: #d1ecf1;
border-radius: 5px;
align-self: flex-start;
max-width: 70%;
margin-bottom: 10px;
}
</style>
</head>
<body>
<div class="container">
<div class="chat-container">
<div class="chat-box" id="chat-box"></div>
<div class="input-container">
                <input type="text" class="chat-input" id="user-input" placeholder="Type a message...">
                <button onclick="sendMessage()">Send</button>
</div>
</div>
</div>
<script>
const chatBox = document.getElementById('chat-box');
const userInput = document.getElementById('user-input');
function saveMessage(sender, message) {
const messageElement = document.createElement('div');
messageElement.textContent = `${sender}: ${message}`;
messageElement.classList.add(`${sender}-message`);
chatBox.appendChild(messageElement);
userInput.value = '';
}
        async function sendMessage() {
            const userMessage = userInput.value.trim();
            if (!userMessage) return;
            saveMessage('user', userMessage);
            try {
                // Encode the message so special characters survive the query string
                const response = await fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`);
                const data = await response.json();
                saveMessage('bot', data.response);
                chatBox.scrollTop = chatBox.scrollHeight;
            } catch (error) {
                console.error('Error:', error);
            }
        }
        // Send on Enter (event.key replaces the deprecated event.keyCode)
        userInput.addEventListener("keyup", function(event) {
            if (event.key === "Enter") {
                event.preventDefault();
                sendMessage();
            }
        });
</script>
</body>
</html>
"""
return HTMLResponse(content=html_code, status_code=200)
# Route that generates a response for the user's query
@app.get('/autocomplete')
async def autocomplete(q: str = Query(...)):
global data_and_models_dict, message_history, tokens_history
    # Lazily load the models on the first request
if 'models' not in data_and_models_dict:
await download_models()
    # Retrieve the cached models
models = data_and_models_dict['models']
best_response = None
    best_score = float('-inf')  # Tracks the best score seen so far
for model, tokenizer, model_name in models:
        # Encode the prompt into input token ids
        input_ids = tokenizer.encode(q, return_tensors="pt")
        tokens_history.append({"input": input_ids.tolist()})  # Record input tokens
        # Randomize the sampling parameters (a top_k of 0 disables top-k filtering)
        top_k = random.randint(0, 50)
        top_p = random.uniform(0.8, 1.0)
        temperature = random.uniform(0.7, 1.5)
        # Sample a response from the model. do_sample=True is required for
        # top_k/top_p/temperature to take effect; without it transformers
        # decodes greedily and ignores those parameters.
        output = model.generate(
            input_ids,
            max_length=50,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id
        )
        response_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Score the candidate; raw length is a crude heuristic and any other
        # ranking criterion could be substituted here
        score = len(response_text)
        # Keep the highest-scoring response across all models
        if score > best_score:
            best_score = score
            best_response = response_text
        # Record output tokens
        output_ids = output[0].tolist()
        tokens_history.append({"output": output_ids})
        # Record the eos and pad token ids
        eos_token = tokenizer.eos_token_id
        pad_token = tokenizer.pad_token_id
        tokens_history.append({"eos_token": eos_token, "pad_token": pad_token})
    # Save the user's message in the history
    message_history.append(q)
    # Return the best generated response
    return {"response": best_response}
# Run the application, preloading the models so requests don't trigger a reload
def run_app():
asyncio.run(download_models())
uvicorn.run(app, host='0.0.0.0', port=7860)
# Entry point
if __name__ == "__main__":
run_app()
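
# Example request once the server is up (sketch; port 7860 as configured above):
#   curl "http://localhost:7860/autocomplete?q=def%20fibonacci(n):"
# The endpoint replies with JSON of the form {"response": "<generated text>"}.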