import os import logging import asyncio import uvicorn import torch from transformers import AutoModelForCausalLM, AutoTokenizer from fastapi import FastAPI, Query, HTTPException from fastapi.responses import HTMLResponse # Configuración de logging logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) # Inicializar la aplicación FastAPI app = FastAPI() # Diccionario para almacenar los modelos data_and_models_dict = {} # Lista para almacenar el historial de mensajes message_history = [] # Función para cargar modelos async def load_models(): gpt_models = ["gpt2-medium", "gpt2-large", "gpt2"] for model_name in gpt_models: try: model = AutoModelForCausalLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) logger.info(f"Successfully loaded {model_name} model") return model, tokenizer except Exception as e: logger.error(f"Failed to load GPT-2 model: {e}") raise HTTPException(status_code=500, detail="Failed to load any models") # Función para descargar modelos async def download_models(): model, tokenizer = await load_models() data_and_models_dict['model'] = (model, tokenizer) @app.get('/') async def main(): html_code = """ ChatGPT Chatbot
""" return HTMLResponse(content=html_code, status_code=200) # Ruta para la generación de respuestas @app.get('/autocomplete') async def autocomplete(q: str = Query(...)): global data_and_models_dict, message_history # Verificar si hay modelos cargados if 'model' not in data_and_models_dict: await download_models() # Obtener el modelo model, tokenizer = data_and_models_dict['model'] # Guardar el mensaje del usuario en el historial message_history.append(q) # Generar una respuesta utilizando el modelo input_ids = tokenizer.encode(q, return_tensors="pt") output = model.generate(input_ids, max_length=50, num_return_sequences=1) response_text = tokenizer.decode(output[0], skip_special_tokens=True) # Guardar la respuesta en el historial message_history.append(response_text) return response_text # Función para ejecutar la aplicación sin reiniciarla def run_app(): asyncio.run(download_models()) uvicorn.run(app, host='0.0.0.0', port=7860) # Ejecutar la aplicación if __name__ == "__main__": run_app()