|
import os |
|
import sys |
|
import torch |
|
import uvicorn |
|
import redis |
|
import numpy as np |
|
from fastapi import FastAPI, Query, BackgroundTasks |
|
from fastapi.responses import HTMLResponse |
|
from starlette.middleware.cors import CORSMiddleware |
|
from datasets import load_dataset |
|
from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline |
|
from loguru import logger |
|
from dotenv import load_dotenv |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
# Add the parent of the current working directory to the import path.
sys.path.append('..')

# Load variables from a local .env file into the process environment.
load_dotenv()

# Third-party credentials taken from the environment (None when unset).
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
kaggle_username = os.getenv('KAGGLE_USERNAME')
kaggle_key = os.getenv('KAGGLE_KEY')

# Redis connection settings. Environment variables are always strings while
# the default is an int, so the port is converted once here; an invalid
# REDIS_PORT now fails at startup instead of at first connection.
redis_host = os.getenv('REDIS_HOST', 'localhost')
redis_port = int(os.getenv('REDIS_PORT', 6379))
# NOTE(review): a fallback password is hard-coded in source; consider
# requiring REDIS_PASSWORD to be set explicitly in every deployment.
redis_password = os.getenv('REDIS_PASSWORD', 'huggingface_spaces')
# decode_responses=True makes reads return str instead of bytes.
redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password, decode_responses=True)

# Maximum number of fields one ``table_<n>`` hash may hold before
# ``store_to_redis_table`` moves on to the next table index.
MAX_ITEMS_PER_TABLE = 10000
|
|
|
def get_current_table_index():
    """Return the index of the ``table_<n>`` hash currently receiving entries.

    The index is stored in Redis under ``current_table_index``; when that key
    is missing (or empty) the active table is table 0.
    """
    stored_index = redis_client.get("current_table_index")
    if not stored_index:
        return 0
    return int(stored_index)
|
|
|
def increment_table_index():
    """Advance the active table pointer by one.

    Reads the current index (missing key counts as 0) and writes back its
    successor with a plain GET followed by SET.
    """
    next_index = get_current_table_index() + 1
    redis_client.set("current_table_index", next_index)
|
|
|
def store_to_redis_table(key, content):
    """Write the field ``key`` with value ``content`` into the active table hash.

    Tables are Redis hashes named ``table_<index>``. If the active table
    already holds ``MAX_ITEMS_PER_TABLE`` fields, the index is advanced first
    and the entry is written to the next table instead.
    """
    active_table = f"table_{get_current_table_index()}"
    if redis_client.hlen(active_table) >= MAX_ITEMS_PER_TABLE:
        # Current table is full: roll over and re-read the new index.
        increment_table_index()
        active_table = f"table_{get_current_table_index()}"
    redis_client.hset(active_table, key, content)
|
|
|
def load_and_store_models(model_names):
    """Load each named model, run one sample generation, and record the result.

    For every name, the model is loaded with ``GPT2LMHeadModel`` and the
    tokenizer with ``AutoTokenizer``. The prompt ``"Sample input"`` is
    generated with ``max_length=50``. The decoded text is stored both in the
    rolling chat table and in the ``models`` hash, whose keys
    ``generate_responses`` iterates over.

    A failure for any name is logged and the loop continues with the next one.
    """
    for model_name in model_names:
        try:
            lm_model = GPT2LMHeadModel.from_pretrained(model_name)
            lm_tokenizer = AutoTokenizer.from_pretrained(model_name)

            prompt_ids = lm_tokenizer.encode("Sample input", return_tensors="pt")
            output_ids = lm_model.generate(prompt_ids, max_length=50)
            sample_output = lm_tokenizer.decode(output_ids[0], skip_special_tokens=True)

            store_to_redis_table(model_name, sample_output)
            # Registering the name here is what makes the model available to
            # generate_responses later on.
            redis_client.hset("models", model_name, sample_output)
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {e}")
|
|
|
app = FastAPI()

# Fully permissive CORS: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins if auth is added.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Process-local list of (role, text) tuples; /autocomplete appends to it.
message_history = []
|
|
|
@app.get('/')
async def index():
    """Serve the single-page chat UI, pre-filled from the active Redis table.

    Every value in the current ``table_<index>`` hash is rendered as a bot
    bubble. Those values include model output generated from user-supplied
    queries (see ``generate_responses``), so they are HTML-escaped before
    being interpolated. Previously they were inserted raw, which allowed
    stored XSS.

    The embedded script sends each message to ``/autocomplete``, which only
    schedules generation as a background task. It then polls
    ``/get_response`` until an answer appears or the attempt budget runs
    out. The previous single immediate fetch typically ran before the
    background task had stored anything.

    Returns:
        HTMLResponse with status 200 containing the full page.
    """
    from html import escape

    chat_history = redis_client.hgetall(f"table_{get_current_table_index()}")
    # Escape stored text so it is displayed literally, never parsed as markup.
    chat_history_html = "".join(
        f"<div class='bot-message'>{escape(msg)}</div>" for msg in chat_history.values()
    )

    # Literal CSS/JS braces are doubled because this is an f-string.
    html_code = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ChatGPT Chatbot</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 0; padding: 0; background-color: #f4f4f4; }}
        .container {{ max-width: 800px; margin: auto; padding: 20px; }}
        .chat-container {{ background-color: #fff; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); overflow: hidden; margin-bottom: 20px; }}
        .chat-box {{ height: 300px; overflow-y: auto; padding: 10px; }}
        .chat-input {{ width: calc(100% - 20px); border: none; border-top: 1px solid #ddd; padding: 10px; font-size: 16px; outline: none; }}
        .user-message, .bot-message {{ margin-bottom: 10px; padding: 8px 12px; border-radius: 8px; max-width: 70%; word-wrap: break-word; }}
        .user-message {{ background-color: #007bff; color: #fff; align-self: flex-end; }}
        .bot-message {{ background-color: #4CAF50; color: #fff; }}
    </style>
</head>
<body>
    <div class="container">
        <h1 style="text-align: center;">ChatGPT Chatbot</h1>
        <div class="chat-container" id="chat-container">
            <div class="chat-box" id="chat-box">
                {chat_history_html}
            </div>
            <input type="text" class="chat-input" id="user-input" placeholder="Type your message...">
        </div>
    </div>
    <script>
        const userInput = document.getElementById('user-input');
        // Generation runs in the background, so the answer is polled for.
        const POLL_INTERVAL_MS = 2000;
        const MAX_POLL_ATTEMPTS = 90;

        userInput.addEventListener('keyup', function(event) {{
            if (event.key === 'Enter') {{
                event.preventDefault();
                sendMessage();
            }}
        }});

        function sendMessage() {{
            const userMessage = userInput.value.trim();
            if (userMessage === '') return;

            appendMessage('user', userMessage);
            userInput.value = '';

            // /autocomplete only schedules generation; poll for the result.
            fetch(`/autocomplete?q=` + encodeURIComponent(userMessage))
                .then(response => response.json())
                .then(() => pollForResponse(userMessage, MAX_POLL_ATTEMPTS))
                .catch(error => {{
                    console.error('Error:', error);
                }});
        }}

        // Query /get_response until it returns a non-null answer or attempts run out.
        function pollForResponse(query, attemptsLeft) {{
            fetch(`/get_response?q=` + encodeURIComponent(query))
                .then(response => response.json())
                .then(data => {{
                    if (data.response !== null && data.response !== undefined) {{
                        appendMessage('bot', data.response);
                    }} else if (attemptsLeft > 1) {{
                        setTimeout(() => pollForResponse(query, attemptsLeft - 1), POLL_INTERVAL_MS);
                    }} else {{
                        appendMessage('bot', 'No response yet; please try again shortly.');
                    }}
                }})
                .catch(error => {{
                    console.error('Error:', error);
                }});
        }}

        function appendMessage(sender, message) {{
            const chatBox = document.getElementById('chat-box');
            const messageElement = document.createElement('div');
            messageElement.className = sender + '-message';
            messageElement.innerText = message;
            chatBox.appendChild(messageElement);
        }}
    </script>
</body>
</html>
"""

    return HTMLResponse(content=html_code, status_code=200)
|
|
|
def calculate_similarity(base_text, candidate_texts):
    """Score each candidate by bag-of-words cosine similarity to ``base_text``.

    The previous version compared one-element vectors holding only the text
    lengths. The cosine of two positive scalars is always 1.0, so every
    candidate tied and ``argmax`` callers always picked the first one.

    Texts are now lower-cased and split into runs of word characters. Each
    text becomes a term-count vector, and the vectors are compared with the
    standard cosine formula.

    Args:
        base_text: Reference text (``generate_responses`` passes the query).
        candidate_texts: Iterable of texts to score against ``base_text``.

    Returns:
        List of floats in [0.0, 1.0], one per candidate, in input order.
        A candidate scores 0.0 when it or ``base_text`` has no word tokens.
    """
    import math
    import re
    from collections import Counter

    def _term_counts(text):
        # Case-fold so "Hello" and "hello" count as the same term.
        return Counter(re.findall(r"\w+", text.lower()))

    def _norm(counts):
        return math.sqrt(sum(c * c for c in counts.values()))

    base_counts = _term_counts(base_text)
    base_norm = _norm(base_counts)

    similarities = []
    for text in candidate_texts:
        counts = _term_counts(text)
        norm = _norm(counts)
        if base_norm == 0 or norm == 0:
            # Zero vector: no shared direction to measure.
            similarities.append(0.0)
            continue
        # Counter returns 0 for absent terms without inserting them.
        dot = sum(count * counts[term] for term, count in base_counts.items())
        similarities.append(dot / (base_norm * norm))
    return similarities
|
|
|
@app.get('/autocomplete')
async def autocomplete(background_tasks: BackgroundTasks, q: str = Query(..., title='query')):
    """Record the user's query and schedule response generation.

    ``background_tasks`` is declared first because Python forbids a parameter
    without a default after one with a default. The previous ordering was a
    SyntaxError that stopped the module from importing. FastAPI resolves
    these parameters by annotation and name, so the HTTP interface
    (``?q=...``) is unchanged.

    Args:
        background_tasks: Injected by FastAPI; ``generate_responses(q)`` is
            queued on it to run after this response is sent.
        q: The user's message, taken from the ``q`` query parameter.

    Returns:
        A status payload. The generated answer is fetched later from
        ``/get_response``.
    """
    # Appending mutates the module-level list, so no ``global`` is needed.
    # NOTE(review): nothing visible here reads or trims message_history; it
    # grows for the lifetime of the process.
    message_history.append(('user', q))

    background_tasks.add_task(generate_responses, q)
    return {"status": "Processing request, please wait..."}
|
|
|
@app.get('/get_response')
async def get_response(q: str = Query(..., title='query')):
    """Return the stored best answer for the query ``q``.

    Looks ``q`` up in the ``responses`` Redis hash, which the background
    generation task fills in. ``response`` is ``None`` while no entry exists.
    """
    return {"response": redis_client.hget("responses", q)}
|
|
|
def _generate_with_model(model_name, q, device):
    """Load ``model_name`` and return five sampled continuations of ``q``.

    Kept in its own function so the model and pipeline become unreferenced
    when it returns. The previous model can then be garbage-collected before
    the next one is loaded, instead of both being alive at once.
    """
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generation_pipeline = pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=device
    )
    outputs = text_generation_pipeline(q, do_sample=True, max_length=50, num_return_sequences=5)
    return [output['generated_text'] for output in outputs]


def generate_responses(q):
    """Generate candidate answers for ``q`` with every registered model.

    Runs as the background task queued by ``/autocomplete``. Each model name
    in the ``models`` Redis hash contributes five sampled generations; a
    model that fails is logged and skipped.

    The candidate scoring highest under ``calculate_similarity(q, ...)`` is
    stored in the ``responses`` hash under ``q``. All candidates, joined by
    newlines, go to the rolling chat table. Errors are logged, never raised.
    """
    try:
        # The device choice is the same for every model: GPU 0 if present, else CPU.
        device = 0 if torch.cuda.is_available() else -1

        generated_responses = []
        for model_name in redis_client.hkeys("models"):
            try:
                generated_responses.extend(_generate_with_model(model_name, q, device))
            except Exception as e:
                logger.error(f"Error generating response with model {model_name}: {e}")

        if not generated_responses:
            logger.warning("No valid responses generated.")
            return

        similarities = calculate_similarity(q, generated_responses)
        most_coherent_response = generated_responses[np.argmax(similarities)]

        store_to_redis_table(q, "\n".join(generated_responses))
        redis_client.hset("responses", q, most_coherent_response)
    except Exception as e:
        logger.error(f"General error in autocomplete: {e}")
|
|
|
if __name__ == '__main__':
    # GPT-2 family checkpoints, smallest to largest.
    gpt2_models = ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]

    # Code-oriented checkpoints.
    # NOTE(review): load_and_store_models loads every name with
    # GPT2LMHeadModel; bert2bert and codegen are presumably not GPT-2
    # architectures - confirm they load meaningfully or use a matching class.
    programming_models = [
        "google/bert2bert_L-24_uncased",
        "microsoft/CodeGPT-small-java",
        "microsoft/CodeGPT-small-python",
        "Salesforce/codegen-350M-multi",
    ]

    # Register every model (sample generation + "models" hash) before serving.
    models_to_load = gpt2_models + programming_models
    load_and_store_models(models_to_load)

    listen_port = int(os.getenv("PORT", 8001))
    uvicorn.run(app=app, host='0.0.0.0', port=listen_port)