import os
import sys
import torch
import uvicorn
import redis
import numpy as np
from fastapi import FastAPI, Query, BackgroundTasks
from fastapi.responses import HTMLResponse
from starlette.middleware.cors import CORSMiddleware
from datasets import load_dataset
from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline
from loguru import logger
from dotenv import load_dotenv
from sklearn.metrics.pairwise import cosine_similarity

sys.path.append('..')
load_dotenv()

huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
kaggle_username = os.getenv('KAGGLE_USERNAME')
kaggle_key = os.getenv('KAGGLE_KEY')

redis_host = os.getenv('REDIS_HOST', 'localhost')
# Environment values are always strings; normalise so the port has the same
# type whether or not REDIS_PORT is set.
redis_port = int(os.getenv('REDIS_PORT', 6379))
# NOTE(review): falling back to a hard-coded password ships a known credential.
# Prefer requiring REDIS_PASSWORD explicitly; kept for backward compatibility.
redis_password = os.getenv('REDIS_PASSWORD', 'huggingface_spaces')
redis_client = redis.Redis(
    host=redis_host,
    port=redis_port,
    password=redis_password,
    decode_responses=True,
)

# Maximum number of fields written into one ``table_<n>`` hash before
# store_to_redis_table rolls over to the next table index.
MAX_ITEMS_PER_TABLE = 10000


def get_current_table_index():
    """Return the active table index stored in Redis (0 if never set)."""
    return int(redis_client.get("current_table_index") or 0)


def increment_table_index():
    """Atomically advance the active table index and return the new value.

    Redis ``INCR`` is atomic, so concurrent writers cannot both read the same
    index and lose an increment (a GET followed by SET could). ``INCR`` treats
    a missing key as 0, matching :func:`get_current_table_index`.
    """
    return redis_client.incr("current_table_index")


def store_to_redis_table(key, content):
    """Write ``key -> content`` into the active ``table_<n>`` hash.

    If the active table already holds ``MAX_ITEMS_PER_TABLE`` fields, the
    table index is advanced first and the item goes into the new table.
    """
    table_name = f"table_{get_current_table_index()}"
    if redis_client.hlen(table_name) >= MAX_ITEMS_PER_TABLE:
        # Use the index INCR handed back instead of re-reading the key, so the
        # write lands in the table this call actually rolled over to.
        table_name = f"table_{increment_table_index()}"
    redis_client.hset(table_name, key, content)


def _generate_sample(name, sample_text, max_length):
    """Load checkpoint *name* and return its decoded continuation of *sample_text*.

    Model and tokenizer are locals of this helper, so they are released when it
    returns rather than staying alive while the caller loads the next checkpoint.
    """
    model = GPT2LMHeadModel.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)
    inputs = tokenizer(sample_text, return_tensors="pt")
    # Passing an explicit attention mask and pad id is the usage recommended by
    # transformers; GPT-2 tokenizers have no pad token, so fall back to EOS.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    generated = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        pad_token_id=pad_token_id,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


def load_and_store_models(model_names, sample_text="Sample input", max_length=50):
    """Generate a sample from each model and cache the text in Redis.

    For every name, the checkpoint is loaded with ``GPT2LMHeadModel`` /
    ``AutoTokenizer``. A continuation of *sample_text* is then generated (up to
    *max_length* tokens). The decoded text is written both to the rotating
    ``table_<n>`` hashes and to the ``models`` hash, keyed by model name.

    Args:
        model_names: Iterable of model identifiers accepted by ``from_pretrained``.
        sample_text: Prompt used for the sample generation.
        max_length: ``max_length`` passed to ``generate``.

    A failure for one model is logged with its traceback and does not stop
    processing of the remaining models.
    """
    for name in model_names:
        try:
            decoded_text = _generate_sample(name, sample_text, max_length)
            store_to_redis_table(name, decoded_text)
            redis_client.hset("models", name, decoded_text)
        except Exception as e:
            # Best-effort per model: keep going, but keep the traceback.
            logger.exception(f"Error loading model {name}: {e}")


app = FastAPI()
app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"] ) message_history = [] @app.get('/') async def index(): chat_history = redis_client.hgetall(f"table_{get_current_table_index()}") chat_history_html = "".join(f"
" for msg in chat_history.values()) html_code = f"""