Sergidev commited on
Commit
487e8b7
1 Parent(s): 9370091

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, Request
2
  from fastapi.responses import HTMLResponse, StreamingResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from modules.pmbl import PMBL
@@ -14,7 +14,12 @@ app = FastAPI(docs_url=None, redoc_url=None)
14
  app.mount("/static", StaticFiles(directory="static"), name="static")
15
  app.mount("/templates", StaticFiles(directory="templates"), name="templates")
16
 
17
- pmbl = PMBL("./PMB-7b.Q6_K.gguf", gpu_layers=50)
 
 
 
 
 
18
 
19
  @app.head("/")
20
  @app.get("/")
@@ -23,12 +28,12 @@ def index() -> HTMLResponse:
23
  return HTMLResponse(content=f.read())
24
 
25
  @app.post("/chat")
26
- async def chat(request: Request):
27
  try:
28
  data = await request.json()
29
  user_input = data["user_input"]
30
  mode = data["mode"]
31
- history = pmbl.get_chat_history(mode, user_input)
32
  response_generator = pmbl.generate_response(user_input, history, mode)
33
  return StreamingResponse(response_generator, media_type="text/plain")
34
  except Exception as e:
@@ -36,9 +41,9 @@ async def chat(request: Request):
36
  return {"error": str(e)}
37
 
38
  @app.post("/sleep")
39
- async def sleep():
40
  try:
41
- pmbl.sleep_mode()
42
  return {"message": "Sleep mode completed successfully"}
43
  except Exception as e:
44
  print(f"[SYSTEM] Error: {str(e)}")
 
1
+ from fastapi import FastAPI, Request, Depends
2
  from fastapi.responses import HTMLResponse, StreamingResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from modules.pmbl import PMBL
 
14
  app.mount("/static", StaticFiles(directory="static"), name="static")
15
  app.mount("/templates", StaticFiles(directory="templates"), name="templates")
16
 
17
+ async def get_pmbl():
18
+ pmbl = PMBL("./PMB-7b.Q6_K.gguf", gpu_layers=50)
19
+ try:
20
+ yield pmbl
21
+ finally:
22
+ await pmbl.close()
23
 
24
  @app.head("/")
25
  @app.get("/")
 
28
  return HTMLResponse(content=f.read())
29
 
30
  @app.post("/chat")
31
+ async def chat(request: Request, pmbl: PMBL = Depends(get_pmbl)):
32
  try:
33
  data = await request.json()
34
  user_input = data["user_input"]
35
  mode = data["mode"]
36
+ history = await pmbl.get_chat_history(mode, user_input)
37
  response_generator = pmbl.generate_response(user_input, history, mode)
38
  return StreamingResponse(response_generator, media_type="text/plain")
39
  except Exception as e:
 
41
  return {"error": str(e)}
42
 
43
  @app.post("/sleep")
44
+ async def sleep(pmbl: PMBL = Depends(get_pmbl)):
45
  try:
46
+ await pmbl.sleep_mode()
47
  return {"message": "Sleep mode completed successfully"}
48
  except Exception as e:
49
  print(f"[SYSTEM] Error: {str(e)}")