import os import sys import torch import uvicorn import redis import numpy as np import random from fastapi import FastAPI, Query, BackgroundTasks from fastapi.responses import HTMLResponse from starlette.middleware.cors import CORSMiddleware from datasets import load_dataset from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline from loguru import logger from dotenv import load_dotenv from sklearn.metrics.pairwise import cosine_similarity from kaggle.api.kaggle_api_extended import KaggleApi # Importar la librería de spaces import spaces sys.path.append('..') load_dotenv() huggingface_token = os.getenv('HUGGINGFACE_TOKEN') kaggle_username = os.getenv('KAGGLE_USERNAME') kaggle_key = os.getenv('KAGGLE_KEY') redis_host = os.getenv('REDIS_HOST', 'localhost') redis_port = os.getenv('REDIS_PORT', 6379) redis_password = os.getenv('REDIS_PASSWORD', 'huggingface_spaces') redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password, decode_responses=True) MAX_ITEMS_PER_TABLE = 10000 # Decorador para usar GPU en Spaces @spaces.GPU() def generate_responses_gpu(q): generated_responses = [] try: for model_name in redis_client.hkeys("models"): try: model_data = redis_client.hget("models", model_name) model = GPT2LMHeadModel.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) # Generar valores aleatorios para top_p, top_k y temperature top_p = round(random.uniform(0.01, 0.99), 2) top_k = random.randint(1, 99) temperature = round(random.uniform(0.01, 1.99), 2) text_generation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1) generated_response = text_generation_pipeline(q, do_sample=True, max_length=50, num_return_sequences=5, top_p=top_p, top_k=top_k, temperature=temperature) generated_responses.extend([response['generated_text'] for response in generated_response]) except Exception as e: logger.error(f"Error generating response with model {model_name}: {e}") if generated_responses: similarities = calculate_similarity(q, generated_responses) most_coherent_response = generated_responses[np.argmax(similarities)] store_to_redis_table(q, "\n".join(generated_responses)) redis_client.hset("responses", q, most_coherent_response) else: logger.warning("No valid responses generated.") except Exception as e: logger.error(f"General error in autocomplete: {e}") def get_current_table_index(): return int(redis_client.get("current_table_index") or 0) def increment_table_index(): current_index = get_current_table_index() redis_client.set("current_table_index", current_index + 1) def store_to_redis_table(key, content): current_index = get_current_table_index() table_name = f"table_{current_index}" item_count = redis_client.hlen(table_name) if item_count >= MAX_ITEMS_PER_TABLE: increment_table_index() table_name = f"table_{get_current_table_index()}" redis_client.hset(table_name, key, content) def load_and_store_models(model_names): for name in model_names: try: model = GPT2LMHeadModel.from_pretrained(name) tokenizer = AutoTokenizer.from_pretrained(name) sample_text = "Sample input" generated_text = model.generate(tokenizer.encode(sample_text, return_tensors="pt"), max_length=50) decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True) store_to_redis_table(name, decoded_text) redis_client.hset("models", name, decoded_text) except Exception as e: logger.error(f"Error loading model {name}: {e}") def load_kaggle_datasets(dataset_names): api = KaggleApi() api.authenticate() for dataset_name in dataset_names: try: api.dataset_download_files(dataset_name, path='./kaggle_datasets', unzip=True) dataset = load_dataset('csv', data_files=[f'./kaggle_datasets/{dataset_name}/*.csv'])['train'] sample_data = dataset.to_pandas().head(10).to_json(orient='records') store_to_redis_table(dataset_name, sample_data) redis_client.hset("kaggle_datasets", dataset_name, sample_data) except Exception as e: logger.error(f"Error loading Kaggle dataset {dataset_name}: {e}") app = FastAPI() app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"] ) message_history = [] @app.get('/') async def index(): chat_history = redis_client.hgetall(f"table_{get_current_table_index()}") chat_history_html = "".join(f"
" for msg in chat_history.values()) html_code = f"""