import os
import sys

import numpy as np
import redis
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, Query
from fastapi.responses import HTMLResponse
from loguru import logger
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline

sys.path.append('..')
load_dotenv()

huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
kaggle_username = os.getenv('KAGGLE_USERNAME')
kaggle_key = os.getenv('KAGGLE_KEY')
redis_host = os.getenv('REDIS_HOST', 'localhost')
redis_port = int(os.getenv('REDIS_PORT', 6379))  # env vars arrive as strings
redis_password = os.getenv('REDIS_PASSWORD', 'huggingface_spaces')

redis_client = redis.Redis(host=redis_host, port=redis_port,
                           password=redis_password, decode_responses=True)

# Chat history is sharded across Redis hashes ("tables"); once the current
# table holds MAX_ITEMS_PER_TABLE entries, writes roll over to the next one.
MAX_ITEMS_PER_TABLE = 10000


def get_current_table_index():
    return int(redis_client.get("current_table_index") or 0)


def increment_table_index():
    current_index = get_current_table_index()
    redis_client.set("current_table_index", current_index + 1)


def store_to_redis_table(key, content):
    current_index = get_current_table_index()
    table_name = f"table_{current_index}"
    item_count = redis_client.hlen(table_name)
    if item_count >= MAX_ITEMS_PER_TABLE:
        increment_table_index()
        table_name = f"table_{get_current_table_index()}"
    redis_client.hset(table_name, key, content)


def load_and_store_models(model_names):
    # Warm each model once with a sample generation and register its name in
    # the "models" hash; failures are logged and skipped.
    for name in model_names:
        try:
            model = GPT2LMHeadModel.from_pretrained(name)
            tokenizer = AutoTokenizer.from_pretrained(name)
            sample_text = "Sample input"
            generated_text = model.generate(
                tokenizer.encode(sample_text, return_tensors="pt"),
                max_length=50,
            )
            decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
            store_to_redis_table(name, decoded_text)
            redis_client.hset("models", name, decoded_text)
        except Exception as e:
            logger.error(f"Error loading model {name}: {e}")


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

message_history = []
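
# Request flow, for reference: a client calls /autocomplete?q=..., which
# queues generate_responses(q) as a background task and returns immediately;
# the client then polls /get_response?q=... until the chosen response appears
# in the "responses" hash. A hypothetical local session (host and port assume
# the defaults used in __main__ below):
#
#   curl 'http://localhost:8001/autocomplete?q=hello+world'
#   {"status": "Processing request, please wait..."}
#   curl 'http://localhost:8001/get_response?q=hello+world'
#   {"response": "..."}
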
@app.get('/')
async def index():
    # Render the current chat-history table as a minimal HTML page.
    chat_history = redis_client.hgetall(f"table_{get_current_table_index()}")
    chat_history_html = "".join(
        f"<div class='message'>{msg}</div>" for msg in chat_history.values()
    )
    html_code = f"""
    <!DOCTYPE html>
    <html>
    <head><title>ChatGPT Chatbot</title></head>
    <body>
        <h1>ChatGPT Chatbot</h1>
        <div id="chat-history">
            {chat_history_html}
        </div>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code, status_code=200)


def calculate_similarity(base_text, candidate_texts):
    # Cosine similarity over the original single-element length vectors is
    # always 1.0, so it could not rank candidates. As a crude stand-in for
    # coherence, score each candidate by how close its length is to the query's.
    return [1.0 / (1.0 + abs(len(base_text) - len(text))) for text in candidate_texts]


@app.get('/autocomplete')
async def autocomplete(background_tasks: BackgroundTasks,
                       q: str = Query(..., title='query')):
    # The non-defaulted BackgroundTasks parameter must come before the
    # defaulted query parameter.
    global message_history
    message_history.append(('user', q))
    background_tasks.add_task(generate_responses, q)
    return {"status": "Processing request, please wait..."}


@app.get('/get_response')
async def get_response(q: str = Query(..., title='query')):
    response = redis_client.hget("responses", q)
    return {"response": response}


def generate_responses(q):
    # Generate candidates with every registered model, pick the best-scoring
    # one, and store both the full candidate list and the winner in Redis.
    generated_responses = []
    try:
        for model_name in redis_client.hkeys("models"):
            try:
                model = GPT2LMHeadModel.from_pretrained(model_name)
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                text_generation_pipeline = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1,
                )
                generated_response = text_generation_pipeline(
                    q, do_sample=True, max_length=50, num_return_sequences=5
                )
                generated_responses.extend(
                    [response['generated_text'] for response in generated_response]
                )
            except Exception as e:
                logger.error(f"Error generating response with model {model_name}: {e}")
        if generated_responses:
            similarities = calculate_similarity(q, generated_responses)
            most_coherent_response = generated_responses[np.argmax(similarities)]
            store_to_redis_table(q, "\n".join(generated_responses))
            redis_client.hset("responses", q, most_coherent_response)
        else:
            logger.warning("No valid responses generated.")
    except Exception as e:
        logger.error(f"General error in generate_responses: {e}")


if __name__ == '__main__':
    gpt2_models = [
        "gpt2",
        "gpt2-medium",
        "gpt2-large",
        "gpt2-xl",
    ]
    programming_models = [
        "google/bert2bert_L-24_uncased",
        "microsoft/CodeGPT-small-java",
        "microsoft/CodeGPT-small-python",
        "Salesforce/codegen-350M-multi",
    ]
    # Checkpoints that are not GPT-2-compatible (e.g. the bert2bert entry)
    # will fail inside load_and_store_models and are logged and skipped.
    load_and_store_models(gpt2_models + programming_models)
    uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT", 8001)))