import os
import sys

import torch
import uvicorn
from fastapi import FastAPI, Query
from fastapi.responses import HTMLResponse
from loguru import logger
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline

sys.path.append('..')

# Use finetuned GPT models listed in xfa.txt (one model name or path per line)
current_dir = os.path.dirname(os.path.realpath(__file__))
text_file_path = os.path.join(current_dir, 'xfa.txt')
with open(text_file_path, 'r') as file:
    model_names = [line.strip() for line in file.readlines()]

models_dict = {}

# Load each listed model and its tokenizer; skip entries that fail to load
for name in model_names:
    try:
        model = GPT2LMHeadModel.from_pretrained(name)
        tokenizer = AutoTokenizer.from_pretrained(name)
        models_dict[name] = {'model': model, 'tokenizer': tokenizer}
    except Exception as e:
        logger.error(f"Error loading model {name}: {e}")

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variable to store the message history
message_history = []


@app.get('/')
async def index():
    html_code = """

    <!-- Placeholder page: title, history toggle, and chat history panel only -->
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>ChatGPT Chatbot</title>
    </head>
    <body>
        <h1>ChatGPT Chatbot</h1>
        <button>Toggle History</button>
        <div>
            <h2>Chat History</h2>
        </div>
    </body>
    </html>

""" return HTMLResponse(content=html_code, status_code=200) @app.get('/autocomplete') async def autocomplete(q: str = Query(..., title='query')): global message_history message_history.append(('user', q)) try: # Use combined models for responses generated_responses = [] for model_name, model_info in models_dict.items(): model = model_info['model'] tokenizer = model_info['tokenizer'] text_generation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1) generated_response = text_generation_pipeline(q, do_sample=True, num_return_sequences=5) generated_responses.extend([response['generated_text'] for response in generated_response]) message_history.extend([('bot', response['generated_text']) for response in generated_response]) logger.debug(f"Successfully autocomplete, q:{q}, res:{generated_responses}") # Find the response closest to the question closest_response = min(generated_responses, key=lambda x: abs(len(x) - len(q))) return {"result": [closest_response]} except Exception as e: logger.error(f"Ignored error in autocomplete: {e}") if __name__ == '__main__': uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT", 8001)))