import os
import logging
import asyncio
import uvicorn
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from fastapi import FastAPI, Query, HTTPException
from fastapi.responses import HTMLResponse, PlainTextResponse
# Logging configuration
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Initialize the FastAPI application
app = FastAPI()
# Dictionary caching the loaded model and tokenizer
data_and_models_dict = {}
# List holding the message history
message_history = []
# Load the first GPT-2 variant that downloads successfully, falling back
# through the list on failure
async def load_models():
    gpt_models = ["gpt2-medium", "gpt2-large", "gpt2"]
    for model_name in gpt_models:
        try:
            model = AutoModelForCausalLM.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            logger.info(f"Successfully loaded {model_name} model")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Failed to load {model_name}: {e}")
    raise HTTPException(status_code=500, detail="Failed to load any models")
# Download the model and cache it in the shared dictionary
async def download_models():
    model, tokenizer = await load_models()
    data_and_models_dict['model'] = (model, tokenizer)
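
# The preload in run_app() below only runs when this file is executed
# directly; when served by an external runner (e.g. `uvicorn app:app`) the
# model would otherwise load lazily on the first request. A minimal sketch
# of a startup hook covering that case (an addition, not in the original):
@app.on_event("startup")
async def preload_models_on_startup():
    # Idempotent: skip if the model has already been cached
    if 'model' not in data_and_models_dict:
        await download_models()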
# Route serving the chat page
@app.get('/')
async def main():
    html_code = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ChatGPT Chatbot</title>
    <style>
        body, html {
            height: 100%;
            margin: 0;
            padding: 0;
            font-family: Arial, sans-serif;
        }
        .container {
            height: 100%;
            display: flex;
            flex-direction: column;
            justify-content: center;
            align-items: center;
        }
        .chat-container {
            border-radius: 10px;
            overflow: hidden;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            width: 100%;
            height: 100%;
        }
        .chat-box {
            height: calc(100% - 60px);
            overflow-y: auto;
            padding: 10px;
        }
        .chat-input {
            width: calc(100% - 100px);
            padding: 10px;
            border: none;
            border-top: 1px solid #ccc;
            font-size: 16px;
        }
        .input-container {
            display: flex;
            align-items: center;
            justify-content: space-between;
            padding: 10px;
            background-color: #f5f5f5;
            border-top: 1px solid #ccc;
            width: 100%;
        }
        button {
            padding: 10px;
            border: none;
            cursor: pointer;
            background-color: #007bff;
            color: #fff;
            font-size: 16px;
        }
        .user-message {
            background-color: #cce5ff;
            border-radius: 5px;
            align-self: flex-end;
            max-width: 70%;
            margin-left: auto;
            margin-right: 10px;
            margin-bottom: 10px;
        }
        .bot-message {
            background-color: #d1ecf1;
            border-radius: 5px;
            align-self: flex-start;
            max-width: 70%;
            margin-bottom: 10px;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="chat-container">
            <div class="chat-box" id="chat-box"></div>
            <div class="input-container">
                <input type="text" class="chat-input" id="user-input" placeholder="Type a message...">
                <button onclick="sendMessage()">Send</button>
            </div>
        </div>
    </div>
    <script>
        const chatBox = document.getElementById('chat-box');
        const userInput = document.getElementById('user-input');

        // Append a message to the chat box and clear the input field
        function saveMessage(sender, message) {
            const messageElement = document.createElement('div');
            messageElement.textContent = `${sender}: ${message}`;
            messageElement.classList.add(`${sender}-message`);
            chatBox.appendChild(messageElement);
            userInput.value = '';
        }

        async function sendMessage() {
            const userMessage = userInput.value.trim();
            if (!userMessage) return;
            saveMessage('user', userMessage);
            try {
                // URL-encode the query so characters like & and # survive
                const response = await fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`);
                const data = await response.text();
                saveMessage('bot', data);
                chatBox.scrollTop = chatBox.scrollHeight;
            } catch (error) {
                console.error('Error:', error);
            }
        }

        // Send on Enter (event.keyCode is deprecated; use event.key)
        userInput.addEventListener("keyup", function(event) {
            if (event.key === "Enter") {
                event.preventDefault();
                sendMessage();
            }
        });
    </script>
</body>
</html>
"""
return HTMLResponse(content=html_code, status_code=200)
# Route generating chat responses
@app.get('/autocomplete')
async def autocomplete(q: str = Query(...)):
    global data_and_models_dict, message_history
    # Load the model lazily on the first request
    if 'model' not in data_and_models_dict:
        await download_models()
    # Fetch the cached model and tokenizer
    model, tokenizer = data_and_models_dict['model']
    # Record the user's message in the history
    message_history.append(q)
    # Generate a response (note: this runs synchronously and blocks the
    # event loop for the duration of generation)
    input_ids = tokenizer.encode(q, return_tensors="pt")
    output = model.generate(
        input_ids,
        max_length=50,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
    )
    response_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Record the response in the history
    message_history.append(response_text)
    # Return plain text; returning a bare str would be JSON-encoded with quotes
    return PlainTextResponse(response_text)
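
# NOTE: message_history records both sides of the conversation but is never
# fed back into the model, so each reply is conditioned only on the latest
# message. A minimal, hypothetical sketch of multi-turn conditioning (the
# helper name and turn limit are assumptions, not part of the original app):
def build_prompt(history, max_turns=4):
    # Join the most recent turns into a single conditioning string
    return "\n".join(history[-max_turns:])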
# Run the application without auto-reloading it
def run_app():
    asyncio.run(download_models())
    uvicorn.run(app, host='0.0.0.0', port=7860)

# Run the application
if __name__ == "__main__":
    run_app()
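
# Example usage once the server is running; GPT-2 echoes the prompt followed
# by its continuation, since the full output sequence is decoded:
#   curl "http://localhost:7860/autocomplete?q=Hello%20there"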