# PMAlpha / app.py
# Sergidev's picture
# Update app.py
# 2f9891f verified
# raw
# history blame
# 2.29 kB
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from modules.pmbl import PMBL
import torch
from queue import Queue
import asyncio
# Import-time diagnostics: report whether torch can see a CUDA device, how
# many devices are visible and, when one is available, the name of device 0.
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
# The application object. docs_url/redoc_url are set to None, which in FastAPI
# turns off the auto-generated interactive documentation pages.
app = FastAPI(docs_url=None, redoc_url=None)
# Expose the ./static and ./templates directories as static file trees.
# NOTE(review): this makes the raw template files downloadable under
# /templates — confirm that is intended.
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/templates", StaticFiles(directory="templates"), name="templates")
# Single shared model wrapper used by every endpoint below.
# Presumably gpu_layers is the number of model layers offloaded to the GPU —
# TODO confirm against modules.pmbl.
pmbl = PMBL("./PMB-7b.Q6_K.gguf", gpu_layers=50)
# Pending chat requests. chat() enqueues (future, user_input, mode) tuples and
# queue_worker() dequeues them. This is the thread-oriented queue.Queue, not
# asyncio.Queue, so the worker polls it instead of awaiting it.
request_queue = Queue()
@app.head("/")
@app.get("/")
def index() -> HTMLResponse:
    """Serve the main page for GET and HEAD requests on "/".

    Returns:
        An HTMLResponse whose body is the content of templates/index.html,
        re-read from disk on every request.

    Raises:
        OSError: if the template file cannot be opened.
        UnicodeDecodeError: if the template is not valid UTF-8.
    """
    # Decode explicitly as UTF-8: without an encoding argument open() falls
    # back to the locale's preferred encoding, so the same template could be
    # served correctly on one host and garbled (or fail to decode) on another.
    with open("templates/index.html", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())
async def process_request(user_input: str, mode: str):
    """Stream the model's reply to one user message.

    This is an async generator: calling it only creates the generator, and
    the history lookup and generation happen once a consumer iterates it.

    Args:
        user_input: The text submitted by the user.
        mode: Mode value passed through unchanged to both PMBL calls.

    Yields:
        Each chunk produced by ``pmbl.generate_response``, in order.
    """
    # Note the argument order: the history lookup takes (mode, user_input),
    # while generation takes (user_input, history, mode).
    chat_history = pmbl.get_chat_history(mode, user_input)
    reply_stream = pmbl.generate_response(user_input, chat_history, mode)
    async for piece in reply_stream:
        yield piece
@app.post("/chat")
async def chat(request: Request, background_tasks: BackgroundTasks):
    """Handle POST /chat by streaming the model's reply as plain text.

    The JSON body must contain the keys "user_input" and "mode". The request
    is placed on ``request_queue`` together with a future; once the queue
    worker resolves that future with an async iterable of chunks, the chunks
    are forwarded to the client.

    Args:
        request: Incoming request whose JSON body carries the chat payload.
        background_tasks: Accepted as part of the signature but not used in
            this body.

    Returns:
        A StreamingResponse with media type "text/plain", or a dict
        ``{"error": <message>}`` when reading the payload or building the
        response raises.
    """
    try:
        payload = await request.json()
        user_input = payload["user_input"]
        mode = payload["mode"]

        async def stream_reply():
            # Runs only once the response starts streaming, i.e. after chat()
            # has returned, so errors raised here are not seen by the
            # surrounding try/except.
            pending = asyncio.Future()
            request_queue.put((pending, user_input, mode))
            reply_chunks = await pending
            async for chunk in reply_chunks:
                yield chunk

        reply = StreamingResponse(stream_reply(), media_type="text/plain")
        return reply
    except Exception as e:
        print(f"[SYSTEM] Error: {str(e)}")
        return {"error": str(e)}
async def queue_worker():
    """Background loop that hands each queued chat request its reply stream.

    Every 0.1 s the loop checks ``request_queue``. When an entry is present it
    takes one ``(future, user_input, mode)`` tuple and resolves the future
    with the async generator returned by ``process_request``. The generator is
    not iterated here; the code awaiting the future consumes it.

    The loop never returns; it ends only when its task is cancelled.
    """
    while True:
        if not request_queue.empty():
            # This coroutine is the only consumer in this file and the queue
            # was just seen non-empty, so get() returns immediately.
            future, user_input, mode = request_queue.get()
            # The waiter may be gone by now (for example its task was
            # cancelled, which also cancels the future it was awaiting).
            # set_result() on a done future raises InvalidStateError, which
            # would end this loop and leave every later request waiting
            # forever, so such entries are dropped.
            if not future.done():
                future.set_result(process_request(user_input, mode))
        await asyncio.sleep(0.1)
@app.on_event("startup")
async def startup_event():
    """Start the queue worker task when the application starts up."""
    # The event loop holds only weak references to tasks, so a task whose
    # result is discarded may be garbage-collected before it finishes. A
    # module-level reference keeps the worker alive for the process lifetime.
    global _queue_worker_task
    _queue_worker_task = asyncio.create_task(queue_worker())
@app.post("/sleep")
async def sleep():
    """Handle POST /sleep by running ``pmbl.sleep_mode()``.

    Returns:
        ``{"message": "Sleep mode completed successfully"}`` when the call
        returns normally, otherwise ``{"error": <message>}`` after printing
        the error.
    """
    try:
        pmbl.sleep_mode()
    except Exception as e:
        print(f"[SYSTEM] Error: {str(e)}")
        return {"error": str(e)}
    else:
        return {"message": "Sleep mode completed successfully"}