Spaces:
Sleeping
Sleeping
import os | |
import logging | |
import asyncio | |
import uvicorn | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from fastapi import FastAPI, Query, HTTPException | |
from fastapi.responses import HTMLResponse | |
# --- Logging -------------------------------------------------------------
# Root logging at DEBUG verbosity; `logger` is this module's named logger.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(name=__name__)

# --- Application -----------------------------------------------------------
# The FastAPI (ASGI) application object handed to uvicorn.
app = FastAPI()

# --- Shared in-process state ----------------------------------------------
# Cache of loaded artifacts; download_models() stores the
# (model, tokenizer) pair here under the key 'model'.
data_and_models_dict: dict = {}

# User prompts and generated replies, appended in order by autocomplete().
message_history: list = []
# Funci贸n para cargar modelos | |
async def load_models(): | |
gpt_models = ["gpt2-medium", "gpt2-large", "gpt2"] | |
for model_name in gpt_models: | |
try: | |
model = AutoModelForCausalLM.from_pretrained(model_name) | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
logger.info(f"Successfully loaded {model_name} model") | |
return model, tokenizer | |
except Exception as e: | |
logger.error(f"Failed to load GPT-2 model: {e}") | |
raise HTTPException(status_code=500, detail="Failed to load any models") | |
# Populate the model cache.
async def download_models():
    """Load a model/tokenizer pair via load_models() and cache it.

    The pair is stored in ``data_and_models_dict`` under the key ``'model'``,
    replacing any pair cached earlier. Exceptions raised by load_models()
    propagate unchanged.
    """
    loaded_model, loaded_tokenizer = await load_models()
    data_and_models_dict.update(model=(loaded_model, loaded_tokenizer))
@app.get("/", response_class=HTMLResponse)
async def main():
    """Serve the single-page chat UI.

    The page's script sends each message to ``GET /autocomplete?q=...`` and
    appends the reply to the chat box. The message is URL-encoded, and the
    reply is read as JSON because the endpoint returns a JSON-encoded string.

    Returns:
        HTMLResponse: the static page, status 200.
    """
    html_code = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ChatGPT Chatbot</title>
<style>
    *, *::before, *::after {
        box-sizing: border-box;
    }
    body, html {
        height: 100%;
        margin: 0;
        padding: 0;
        font-family: Arial, sans-serif;
    }
    .container {
        height: 100%;
        display: flex;
        flex-direction: column;
        justify-content: center;
        align-items: center;
    }
    .chat-container {
        border-radius: 10px;
        overflow: hidden;
        box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
        width: 100%;
        height: 100%;
    }
    .chat-box {
        height: calc(100% - 60px);
        overflow-y: auto;
        padding: 10px;
    }
    .chat-input {
        width: calc(100% - 100px);
        padding: 10px;
        border: none;
        border-top: 1px solid #ccc;
        font-size: 16px;
    }
    .input-container {
        display: flex;
        align-items: center;
        justify-content: space-between;
        padding: 10px;
        background-color: #f5f5f5;
        border-top: 1px solid #ccc;
        width: 100%;
    }
    button {
        padding: 10px;
        border: none;
        cursor: pointer;
        background-color: #007bff;
        color: #fff;
        font-size: 16px;
    }
    .user-message {
        background-color: #cce5ff;
        border-radius: 5px;
        align-self: flex-end;
        max-width: 70%;
        margin-left: auto;
        margin-right: 10px;
        margin-bottom: 10px;
    }
    .bot-message {
        background-color: #d1ecf1;
        border-radius: 5px;
        align-self: flex-start;
        max-width: 70%;
        margin-bottom: 10px;
    }
</style>
</head>
<body>
<div class="container">
    <div class="chat-container">
        <div class="chat-box" id="chat-box"></div>
        <div class="input-container">
            <input type="text" class="chat-input" id="user-input" placeholder="Escribe un mensaje...">
            <button onclick="sendMessage()">Enviar</button>
        </div>
    </div>
</div>
<script>
    const chatBox = document.getElementById('chat-box');
    const userInput = document.getElementById('user-input');

    function saveMessage(sender, message) {
        const messageElement = document.createElement('div');
        // textContent (not innerHTML): model output is never parsed as HTML.
        messageElement.textContent = `${sender}: ${message}`;
        messageElement.classList.add(`${sender}-message`);
        chatBox.appendChild(messageElement);
        chatBox.scrollTop = chatBox.scrollHeight;
    }

    async function sendMessage() {
        const userMessage = userInput.value.trim();
        if (!userMessage) return;
        // Clear only here, so text typed while waiting for a reply survives.
        userInput.value = '';
        saveMessage('user', userMessage);
        try {
            // Encode so '&', '#', '+', '%' in the message cannot corrupt the query string.
            const response = await fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`);
            if (!response.ok) {
                throw new Error(`HTTP ${response.status}`);
            }
            // The endpoint returns a JSON-encoded string; json() removes the quotes/escapes.
            const data = await response.json();
            saveMessage('bot', data);
        } catch (error) {
            console.error('Error:', error);
        }
    }

    userInput.addEventListener("keyup", function(event) {
        if (event.key === "Enter") {
            event.preventDefault();
            sendMessage();
        }
    });
</script>
</body>
</html>
"""
    return HTMLResponse(content=html_code, status_code=200)
# Upper bound on tokens generated per reply. The original used max_length=50,
# which also counts prompt tokens, so prompts of 50+ tokens got no new text
# (or an error, depending on the transformers version).
_MAX_NEW_TOKENS = 50
# Number of most recent entries (prompts and replies) kept in message_history.
_MAX_HISTORY = 1000


def _generate_reply(model, tokenizer, prompt, max_new_tokens=_MAX_NEW_TOKENS):
    """Generate a continuation of ``prompt`` and decode the whole sequence.

    Blocking (runs the model), so callers should run it off the event loop.
    When the model config exposes ``max_position_embeddings``, the prompt is
    truncated so that prompt + new tokens stay within that window.
    """
    encode_kwargs = {}
    context_size = getattr(model.config, "max_position_embeddings", None)
    if context_size:
        # Clip overly long user messages instead of letting generate() fail.
        encode_kwargs = {
            "truncation": True,
            "max_length": max(context_size - max_new_tokens, 1),
        }
    # tokenizer(...) yields input_ids *and* attention_mask for generate().
    inputs = tokenizer(prompt, return_tensors="pt", **encode_kwargs)
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        pad_token_id=pad_token_id,
    )
    # output[0] holds the prompt tokens followed by the generated ones.
    return tokenizer.decode(output[0], skip_special_tokens=True)


# Route that generates the bot reply for one user message.
@app.get("/autocomplete")
async def autocomplete(q: str = Query(...)):
    """Generate a continuation of ``q`` with the cached causal LM.

    Loads the model on first use if it is not cached yet. Records the prompt
    and the reply in ``message_history``, keeping only the most recent
    ``_MAX_HISTORY`` entries. Returns the decoded text, which includes the
    prompt because the full generated sequence is decoded.

    Args:
        q: The user's message, taken from the ``q`` query parameter.

    Returns:
        The decoded text. FastAPI serialises it as a JSON string.

    Raises:
        HTTPException: 400 if ``q`` is empty; 500 (from the model loader)
            if no model could be loaded.
    """
    if not q:
        # An empty prompt encodes to zero tokens, which generate() cannot extend.
        raise HTTPException(status_code=400, detail="Query parameter 'q' must not be empty")

    # NOTE(review): if load_models() yields to the event loop, concurrent first
    # requests could each start a load. run_app() preloads the model before
    # serving, which avoids this.
    if 'model' not in data_and_models_dict:
        await download_models()
    model, tokenizer = data_and_models_dict['model']

    message_history.append(q)

    # Generation is compute-bound; run it in the default executor so the
    # event loop keeps serving other requests meanwhile.
    loop = asyncio.get_running_loop()
    response_text = await loop.run_in_executor(
        None, _generate_reply, model, tokenizer, q
    )

    message_history.append(response_text)
    # Bound memory use in a long-running server.
    del message_history[:-_MAX_HISTORY]
    return response_text
# Start the server once with the model already loaded.
def run_app(host='0.0.0.0', port=7860):
    """Preload the model, then serve ``app`` with uvicorn (blocks until exit).

    Loading happens once before the server starts, so the first request does
    not pay the loading cost. Any loader exception (HTTPException when no
    model loads) propagates and aborts startup.

    Args:
        host: Interface to bind. The default listens on all interfaces.
        port: TCP port to bind. The default is 7860.
    """
    # asyncio.run() drives the async loader to completion on a temporary
    # event loop; uvicorn.run() then starts its own loop for serving.
    asyncio.run(download_models())
    uvicorn.run(app, host=host, port=port)


# Script entry point: preload the model and start serving.
if __name__ == "__main__":
    run_app()