Update app.py
app.py CHANGED
@@ -3,7 +3,6 @@ from fastapi.responses import HTMLResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from modules.pmbl import PMBL
 import torch
-from queue import Queue
 import asyncio
 
 print(f"CUDA available: {torch.cuda.is_available()}")
@@ -17,7 +16,8 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 app.mount("/templates", StaticFiles(directory="templates"), name="templates")
 
 pmbl = PMBL("./PMB-7b.Q6_K.gguf", gpu_layers=50)
-request_queue = Queue()
+request_queue = asyncio.Queue()
+processing_lock = asyncio.Lock()
 
 @app.head("/")
 @app.get("/")
@@ -26,9 +26,10 @@ def index() -> HTMLResponse:
         return HTMLResponse(content=f.read())
 
 async def process_request(user_input: str, mode: str):
-
-
-
+    async with processing_lock:
+        history = pmbl.get_chat_history(mode, user_input)
+        async for chunk in pmbl.generate_response(user_input, history, mode):
+            yield chunk
 
 @app.post("/chat")
 async def chat(request: Request, background_tasks: BackgroundTasks):
@@ -38,11 +39,8 @@ async def chat(request: Request, background_tasks: BackgroundTasks):
     mode = data["mode"]
 
     async def response_generator():
-
-
-        await future
-
-        async for chunk in future.result():
+        await request_queue.put((user_input, mode))
+        async for chunk in await process_request(user_input, mode):
             yield chunk
 
     return StreamingResponse(response_generator(), media_type="text/plain")
@@ -52,11 +50,10 @@ async def chat(request: Request, background_tasks: BackgroundTasks):
 
 async def queue_worker():
     while True:
-
-
-
-
-        await asyncio.sleep(0.1)
+        user_input, mode = await request_queue.get()
+        async for _ in process_request(user_input, mode):
+            pass
+        request_queue.task_done()
 
 @app.on_event("startup")
 async def startup_event():