|
import asyncio
import functools
import logging
import os
import random
from collections import deque

import uvicorn
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
# Root log level. Defaults to DEBUG as before; override with the LOG_LEVEL
# environment variable (e.g. LOG_LEVEL=INFO) to quiet library debug output.
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "DEBUG").upper())

logger = logging.getLogger(__name__)

app = FastAPI()

# Model cache: download_models() stores the list of loaded
# (model, tokenizer, model_name) tuples under the 'models' key.
data_and_models_dict = {}

# Upper bound on the entries kept in each history below. Both grow on every
# /autocomplete request, so an unbounded list would leak memory for as long as
# the server runs; a bounded deque silently drops the oldest entries instead.
HISTORY_MAXLEN = 1000

# Raw prompts received by /autocomplete, oldest first.
message_history = deque(maxlen=HISTORY_MAXLEN)

# Token-level log appended to by /autocomplete: dicts keyed "input", "output",
# or "eos_token"/"pad_token".
tokens_history = deque(maxlen=HISTORY_MAXLEN)
|
|
|
|
|
async def load_models(model_names=None):
    """Load each candidate causal language model together with its tokenizer.

    A model that fails to load is logged with its traceback and skipped, so
    one bad checkpoint does not prevent the others from being used.

    Note: ``from_pretrained`` is synchronous, so although this is a coroutine
    it blocks the running event loop until every model has been processed.

    Args:
        model_names: Optional iterable of Hugging Face model ids to load.
            Defaults to ``gpt2-medium``, ``gpt2-large``, ``gpt2``,
            ``google/gemma-2-9b`` and three code-oriented models.

    Returns:
        A list of ``(model, tokenizer, model_name)`` tuples, in input order,
        for the models that loaded successfully.

    Raises:
        HTTPException: with status 500 if no model could be loaded.
    """
    if model_names is None:
        programming_models = [
            "microsoft/CodeGPT-small-py",
            "Salesforce/codegen-350M-multi",
            "Salesforce/codegen-2B-multi",
        ]
        model_names = ["gpt2-medium", "gpt2-large", "gpt2", "google/gemma-2-9b"] + programming_models

    models = []
    for model_name in model_names:
        try:
            # Tokenizer first: it is typically far smaller than the weights,
            # so a missing/gated repo fails before the model is fetched.
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(model_name)
        except Exception as e:
            # Deliberate best effort: keep the traceback, move on to the next model.
            logger.exception("Failed to load %s model: %s", model_name, e)
            continue
        logger.info("Successfully loaded %s model", model_name)
        models.append((model, tokenizer, model_name))

    if not models:
        raise HTTPException(status_code=500, detail="Failed to load any models")

    return models
|
|
|
|
|
async def download_models():
    """Load all models and cache them under the ``'models'`` key.

    Awaits :func:`load_models` (which raises ``HTTPException`` when no model
    loads) and stores the returned list of ``(model, tokenizer, model_name)``
    tuples in the module-level ``data_and_models_dict``.
    """
    data_and_models_dict['models'] = await load_models()
|
|
|
@app.get('/')
async def main():
    """Serve the single-page chat UI.

    Each message typed by the user is sent as the ``q`` query parameter of a
    GET request to ``/autocomplete``; the ``response`` field of the JSON reply
    is appended to the chat as a bot message.

    Returns:
        HTMLResponse: the static page, status 200.
    """
    html_code = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>ChatGPT Chatbot</title>
        <style>
            body, html {
                height: 100%;
                margin: 0;
                padding: 0;
                font-family: Arial, sans-serif;
            }
            .container {
                height: 100%;
                display: flex;
                flex-direction: column;
                justify-content: center;
                align-items: center;
            }
            .chat-container {
                border-radius: 10px;
                overflow: hidden;
                box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
                width: 100%;
                height: 100%;
            }
            .chat-box {
                box-sizing: border-box; /* keep padding inside the computed height */
                height: calc(100% - 60px);
                overflow-y: auto;
                padding: 10px;
            }
            .chat-input {
                width: calc(100% - 100px);
                padding: 10px;
                border: none;
                border-top: 1px solid #ccc;
                font-size: 16px;
            }
            .input-container {
                box-sizing: border-box; /* 100% width + padding would otherwise clip the button */
                display: flex;
                align-items: center;
                justify-content: space-between;
                padding: 10px;
                background-color: #f5f5f5;
                border-top: 1px solid #ccc;
                width: 100%;
            }
            button {
                padding: 10px;
                border: none;
                cursor: pointer;
                background-color: #007bff;
                color: #fff;
                font-size: 16px;
            }
            .user-message {
                background-color: #cce5ff;
                border-radius: 5px;
                align-self: flex-end;
                max-width: 70%;
                margin-left: auto;
                margin-right: 10px;
                margin-bottom: 10px;
            }
            .bot-message {
                background-color: #d1ecf1;
                border-radius: 5px;
                align-self: flex-start;
                max-width: 70%;
                margin-bottom: 10px;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <div class="chat-container">
                <div class="chat-box" id="chat-box"></div>
                <div class="input-container">
                    <input type="text" class="chat-input" id="user-input" placeholder="Escribe un mensaje...">
                    <button onclick="sendMessage()">Enviar</button>
                </div>
            </div>
        </div>
        <script>
            const chatBox = document.getElementById('chat-box');
            const userInput = document.getElementById('user-input');

            // Append one message bubble and keep the newest message in view.
            // textContent (not innerHTML) so model output is never parsed as HTML.
            function saveMessage(sender, message) {
                const messageElement = document.createElement('div');
                messageElement.textContent = `${sender}: ${message}`;
                messageElement.classList.add(`${sender}-message`);
                chatBox.appendChild(messageElement);
                chatBox.scrollTop = chatBox.scrollHeight;
            }

            async function sendMessage() {
                const userMessage = userInput.value.trim();
                if (!userMessage) return;

                // Clear the input now, so a late bot reply never wipes text typed meanwhile.
                userInput.value = '';
                saveMessage('user', userMessage);

                try {
                    // Encode the prompt: raw '&', '#', '+' or '%' would corrupt the query string.
                    const response = await fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`);
                    const data = await response.json();
                    if (!response.ok) {
                        throw new Error(data.detail || `HTTP ${response.status}`);
                    }
                    saveMessage('bot', data.response);
                } catch (error) {
                    console.error('Error:', error);
                }
            }

            userInput.addEventListener("keyup", function(event) {
                if (event.key === 'Enter') {
                    event.preventDefault();
                    sendMessage();
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code, status_code=200)
|
|
|
|
|
@app.get('/autocomplete')
async def autocomplete(q: str = Query(...)):
    """Continue the prompt ``q`` with every loaded model; return the longest result.

    Each model samples a continuation with randomly drawn ``top_k``, ``top_p``
    and ``temperature``. The decoded text (which still contains the prompt)
    with the most characters is returned. Prompt/output token ids and the
    tokenizer's EOS/pad ids are appended to ``tokens_history``; the raw
    prompt is appended to ``message_history`` on success.

    Args:
        q: User prompt, from the ``q`` query-string parameter.

    Returns:
        ``{"response": <best decoded text>}``.

    Raises:
        HTTPException: 400 if ``q`` is empty or whitespace-only; 500 if no
            model could be loaded or every model failed to generate.
    """
    # An empty prompt gives generate() zero input tokens, which fails.
    if not q.strip():
        raise HTTPException(status_code=400, detail="Query parameter 'q' must not be empty")

    # Lazy load when the app is served without run_app() having preloaded models.
    if 'models' not in data_and_models_dict:
        await download_models()

    models = data_and_models_dict['models']
    loop = asyncio.get_running_loop()

    best_response = None
    best_score = float('-inf')

    for model, tokenizer, model_name in models:
        encoded = tokenizer(q, return_tensors="pt")
        input_ids = encoded["input_ids"]
        tokens_history.append({"input": input_ids.tolist()})

        # Fresh random sampling settings per model and per request.
        # (top_k == 0 disables top-k filtering in transformers.)
        top_k = random.randint(0, 50)
        top_p = random.uniform(0.8, 1.0)
        temperature = random.uniform(0.7, 1.5)

        # Some tokenizers (e.g. GPT-2's) define no pad token; fall back to EOS
        # explicitly rather than letting generate() do it with a warning each call.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        generate = functools.partial(
            model.generate,
            input_ids,
            attention_mask=encoded.get("attention_mask"),
            # NOTE(review): max_length counts the prompt tokens as well, so
            # prompts near 50 tokens get little or no continuation;
            # consider max_new_tokens instead.
            max_length=50,
            # Without do_sample=True, generate() decodes greedily (unless the
            # model's generation config enables sampling) and the top_k/top_p/
            # temperature values above would be ignored.
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=1,
            pad_token_id=pad_token_id,
        )
        try:
            # generate() is synchronous and compute-heavy; running it in the
            # default executor keeps the event loop free for other requests.
            output = await loop.run_in_executor(None, generate)
        except Exception:
            # One failing model (e.g. out of memory) should not sink the request.
            logger.exception("Generation failed for %s model", model_name)
            continue

        response_text = tokenizer.decode(output[0], skip_special_tokens=True)

        # Heuristic: the longest decoded text wins.
        score = len(response_text)
        if score > best_score:
            best_score = score
            best_response = response_text

        tokens_history.append({"output": output[0].tolist()})
        tokens_history.append({"eos_token": tokenizer.eos_token_id, "pad_token": tokenizer.pad_token_id})

    if best_response is None:
        raise HTTPException(status_code=500, detail="All models failed to generate a response")

    message_history.append(q)

    return {"response": best_response}
|
|
|
|
|
def run_app(host='0.0.0.0', port=7860):
    """Preload all models, then serve the FastAPI app with uvicorn.

    Models are loaded in a throwaway event loop before the server starts, so
    the first /autocomplete request does not pay the loading cost. If no model
    loads, the ``HTTPException`` raised by the loader propagates and the
    server is never started.

    Args:
        host: Interface to bind; defaults to all interfaces.
        port: TCP port to listen on; defaults to 7860.
    """
    asyncio.run(download_models())
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    run_app()
|
|