# PMAlpha / app.py
# Last change: "Update app.py" by Sergidev (commit 487e8b7, verified).
# File size: 1.69 kB.
from fastapi import FastAPI, Request, Depends
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from modules.pmbl import PMBL
import torch
def _report_cuda() -> None:
    """Print torch's view of CUDA: availability, device count, first device name.

    The device name is printed only when CUDA is reported as available.
    """
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    if not torch.cuda.is_available():
        return
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")


# Emit the diagnostics once, at import time.
_report_cuda()
# The application object; both docs_url and redoc_url are explicitly set to None.
app = FastAPI(docs_url=None, redoc_url=None)

# Mount each on-disk folder under its URL prefix, registered under the folder's
# own name. Note that the templates folder is served as static files too.
for _prefix, _directory in (("/static", "static"), ("/templates", "templates")):
    app.mount(_prefix, StaticFiles(directory=_directory), name=_directory)
async def get_pmbl():
    """Yield a freshly built PMBL instance and close it once the caller is done.

    This is an async generator used with ``Depends`` by the route handlers in
    this module. Each invocation constructs its own ``PMBL`` for the model file
    ``./PMB-7b.Q6_K.gguf`` with ``gpu_layers=50``.

    Yields:
        The ``PMBL`` instance created for this invocation.
    """
    engine = PMBL("./PMB-7b.Q6_K.gguf", gpu_layers=50)
    try:
        yield engine
    finally:
        # Runs whether the consumer finished normally or raised.
        await engine.close()
@app.head("/")
@app.get("/")
def index() -> HTMLResponse:
    """Serve the landing page for GET and HEAD requests to ``/``.

    The page is read from ``templates/index.html`` on every request and
    returned unchanged as the response body.

    Returns:
        An ``HTMLResponse`` whose content is the text of the template file.

    Raises:
        OSError: If the template file cannot be opened.
        UnicodeDecodeError: If the template file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8. Without ``encoding`` open() falls back to the
    # locale's preferred encoding, so non-ASCII characters in the page could be
    # garbled (or fail to decode) depending on the host's configuration.
    with open("templates/index.html", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())
@app.post("/chat")
async def chat(request: Request, pmbl: PMBL = Depends(get_pmbl)):
    """Answer a chat message with a streamed plain-text reply.

    The request body must be JSON containing the keys ``user_input`` and
    ``mode``. The chat history is fetched from ``pmbl`` for that mode and
    input, then ``pmbl.generate_response`` is wrapped in a ``StreamingResponse``.

    Args:
        request: Incoming request whose JSON body carries the chat payload.
        pmbl: PMBL instance supplied by the ``get_pmbl`` dependency.

    Returns:
        A ``text/plain`` ``StreamingResponse`` on success; otherwise a dict
        ``{"error": <message>}`` describing the exception that was caught.
    """
    # NOTE(review): if generate_response is lazy (e.g. a generator), errors
    # raised while the body is being streamed happen after this function has
    # returned and are therefore not caught by the handler below - confirm.
    # NOTE(review): get_pmbl closes pmbl in its cleanup step; verify that the
    # installed FastAPI version runs that cleanup only after streaming ends.
    try:
        payload = await request.json()
        prompt, chat_mode = payload["user_input"], payload["mode"]
        prior = await pmbl.get_chat_history(chat_mode, prompt)
        return StreamingResponse(
            pmbl.generate_response(prompt, prior, chat_mode),
            media_type="text/plain",
        )
    except Exception as e:
        # Any failure (bad JSON, missing key, PMBL error) is logged to stdout
        # and reported to the client as a JSON object rather than re-raised.
        print(f"[SYSTEM] Error: {str(e)}")
        return {"error": str(e)}
@app.post("/sleep")
async def sleep(pmbl: PMBL = Depends(get_pmbl)):
    """Run ``pmbl.sleep_mode()`` and report the outcome as a JSON object.

    Args:
        pmbl: PMBL instance supplied by the ``get_pmbl`` dependency.

    Returns:
        ``{"message": "Sleep mode completed successfully"}`` when the call
        finishes without raising, otherwise ``{"error": <message>}``.
    """
    try:
        await pmbl.sleep_mode()
    except Exception as e:
        # Failures are logged to stdout and returned to the client, not raised.
        print(f"[SYSTEM] Error: {str(e)}")
        return {"error": str(e)}
    else:
        return {"message": "Sleep mode completed successfully"}