import os
import sys
import torch
import uvicorn
from fastapi import FastAPI, Query
from fastapi.responses import HTMLResponse
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, pipeline, GPT2LMHeadModel
from loguru import logger
sys.path.append('..')
# Read the list of fine-tuned GPT model names from xfa.txt (one model name per line)
current_dir = os.path.dirname(os.path.realpath(__file__))
text_file_path = os.path.join(current_dir, 'xfa.txt')
with open(text_file_path, 'r') as file:
    model_names = [line.strip() for line in file if line.strip()]
models_dict = {}
# Load each model and its tokenizer, skipping any that fail to load
for name in model_names:
    try:
        model = GPT2LMHeadModel.from_pretrained(name)
        tokenizer = AutoTokenizer.from_pretrained(name)
        models_dict[name] = {
            'model': model,
            'tokenizer': tokenizer
        }
    except Exception as e:
        logger.error(f"Error loading model {name}: {e}")
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variable to store the messages history
message_history = []
@app.get('/')
async def index():
    # Minimal page: title, heading, and a "Toggle History" button wired to a placeholder history panel
    html_code = """
    <html>
      <head><title>ChatGPT Chatbot</title></head>
      <body>
        <h1>ChatGPT Chatbot</h1>
        <button onclick="var h = document.getElementById('history'); h.hidden = !h.hidden;">Toggle History</button>
        <div id="history" hidden></div>
      </body>
    </html>
    """
    return HTMLResponse(content=html_code, status_code=200)
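# Query endpoint: run the prompt through every loaded model and return the candidate closest in length to the query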
@app.get('/autocomplete')
async def autocomplete(q: str = Query(..., title='query')):
    global message_history
    message_history.append(('user', q))
    try:
        # Generate candidate responses with every loaded model
        generated_responses = []
        for model_name, model_info in models_dict.items():
            model = model_info['model']
            tokenizer = model_info['tokenizer']
            # Note: the pipeline is rebuilt on every request; the models themselves are loaded once at startup
            text_generation_pipeline = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device=0 if torch.cuda.is_available() else -1,
            )
            generated_response = text_generation_pipeline(q, do_sample=True, num_return_sequences=5)
            generated_responses.extend([response['generated_text'] for response in generated_response])
            message_history.extend([('bot', response['generated_text']) for response in generated_response])
        logger.debug(f"Autocomplete succeeded, q: {q}, res: {generated_responses}")
        # Return the candidate whose length is closest to the query's length
        closest_response = min(generated_responses, key=lambda x: abs(len(x) - len(q)))
        return {"result": [closest_response]}
    except Exception as e:
        logger.error(f"Ignored error in autocomplete: {e}")
        return {"result": []}
if __name__ == '__main__':
    uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT", 8001)))