Implemented async for performance gain
Files changed:
- app/api/userchat.py    +1 -1
- app/main.py            +4 -3
- app/utils/chat_rag.py  +37 -34
app/api/userchat.py CHANGED
@@ -11,7 +11,7 @@ async def chat_with_llama(user_input: str = Body(..., embed=True), current_user:
     # Example logic for model inference (pseudo-code, adjust as necessary)
     try:
         user_id = current_user["user_id"]
-        model_response = llm_infer(user_collection_name=sanitize_collection_name(user_id), prompt=user_input)
+        model_response = await llm_infer(user_collection_name=sanitize_collection_name(user_id), prompt=user_input)
         # Optionally, store chat history
         # chromadb_face_helper.store_chat_history(user_id=current_user["user_id"], user_input=user_input, model_response=model_response)
     except Exception as e:
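Awaiting llm_infer only helps if the coroutine actually yields; an async def that runs CPU-bound inference inline still blocks the event loop for the whole call. A minimal sketch of one way to make a blocking model call awaitable (blocking_call here is a hypothetical stand-in for the real chain invocation, not this project's code):

import asyncio

async def llm_infer(user_collection_name: str, prompt: str) -> str:
    def blocking_call() -> str:
        # Hypothetical stand-in for the real synchronous inference call.
        return f"[{user_collection_name}] response to: {prompt}"

    loop = asyncio.get_running_loop()
    # Hand the blocking work to the default thread pool so the FastAPI
    # event loop can keep serving other requests while inference runs.
    return await loop.run_in_executor(None, blocking_call)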
app/main.py CHANGED
@@ -11,7 +11,7 @@ from admin import admin_functions as admin
 from utils.db import UserFaceEmbeddingFunction,ChromaDBFaceHelper
 from api import userlogin, userlogout, userchat, userupload
 from utils.db import ChromaDBFaceHelper
-from utils.chat_rag import
+from utils.chat_rag import LlamaModelSingleton
 
 app = FastAPI()
 
@@ -42,8 +42,9 @@ async def startup_event():
     chromadb_face_helper = ChromaDBFaceHelper(db_path) # Used by APIs
 
     # Perform any other startup tasks here
-    #
-
+    # Preload the LLM model
+    await LlamaModelSingleton.get_instance()
+    print("LLM model loaded and ready.")
 
     print(f"MODEL_PATH in main.py = {os.getenv('MODEL_PATH')} ")
 
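Awaiting the singleton in the startup hook moves the expensive load out of the request path: the first chat request no longer stalls while the weights are read from disk. A self-contained sketch of the pattern (the singleton body is a placeholder, not the project's class):

import asyncio
from fastapi import FastAPI

app = FastAPI()

class ModelSingleton:
    _instance = None

    @classmethod
    async def get_instance(cls):
        if cls._instance is None:
            await asyncio.sleep(0)  # placeholder for the slow one-time load
            cls._instance = object()
        return cls._instance

@app.on_event("startup")
async def startup_event():
    # The first call pays the full load cost; later calls return the cached instance.
    await ModelSingleton.get_instance()
    print("LLM model loaded and ready.")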
app/utils/chat_rag.py CHANGED
@@ -2,6 +2,7 @@
 import os
 import re
 import hashlib
+import asyncio
 
 from langchain.document_loaders import PyPDFLoader
 
@@ -47,7 +48,7 @@ def sanitize_collection_name(email):
 
 
 # Modify vectordb initialization to be dynamic based on user_id
-def get_vectordb_for_user(user_collection_name):
+async def get_vectordb_for_user(user_collection_name):
     # Get Chromadb location
     CHROMADB_LOC = os.getenv('CHROMADB_LOC')
 
@@ -60,9 +61,9 @@ def get_vectordb_for_user(user_collection_name):
 
 vectordb_cache = {}
 
-def get_vectordb_for_user_cached(user_collection_name):
+async def get_vectordb_for_user_cached(user_collection_name):
     if user_collection_name not in vectordb_cache:
-        vectordb_cache[user_collection_name] = get_vectordb_for_user(user_collection_name)
+        vectordb_cache[user_collection_name] = await get_vectordb_for_user(user_collection_name)
     return vectordb_cache[user_collection_name]
 
 
@@ -93,42 +94,44 @@ def pdf_to_vec(filename, user_collection_name):
     return(vectordb)
     #return collection # Return the collection as the asset
 
+
+# Assuming LlamaModelSingleton is updated to support async instantiation
 class LlamaModelSingleton:
     _instance = None
 
-    ...
+    @classmethod
+    async def get_instance(cls):
         if cls._instance is None:
-            ...
-        return LlamaModelSingleton()
+            cls._instance = cls._load_llm()  # Assuming _load_llm is synchronous; if not, use an executor
+        return cls._instance
+
+    @staticmethod
+    def _load_llm():
+        print('Loading LLM model...')
+        model_path = os.getenv("MODEL_PATH")
+        llm = LlamaCpp(
+            model_path=model_path,
+            n_gpu_layers=-1,
+            n_batch=512,
+            temperature=0.1,
+            top_p=1,
+            max_tokens=2000,
+        )
+        print(f'Model loaded from {model_path}')
+        return llm
+
+async def load_llm():
+    return await LlamaModelSingleton.get_instance()
+
 
 
 
 # step 5: instantiate once to create default_chain, router_chain, destination_chains and set vectordb, so they are not re-created per prompt
-def default_chain(llm, user_collection_name):
+async def default_chain(llm, user_collection_name):
     # Get Chromadb location
     CHROMADB_LOC = os.getenv('CHROMADB_LOC')
 
-    vectordb = get_vectordb_for_user_cached(user_collection_name) # Use the dynamic vectordb based on user_id
+    vectordb = await get_vectordb_for_user_cached(user_collection_name) # Use the dynamic vectordb based on user_id
     sum_template = """
    As a machine learning education specialist, our expertise is pivotal in deepening the comprehension of complex machine learning concepts for both educators and students.
 
@@ -209,13 +212,13 @@ def default_chain(llm, user_collection_name):
     return default_chain,router_chain,destination_chains
 
 # Adjust llm_infer to accept user_id and use it for user-specific processing
-def llm_infer(user_collection_name, prompt):
+async def llm_infer(user_collection_name, prompt):
 
-    llm = load_llm() # load_llm is singleton for entire system
+    llm = await load_llm() # load_llm is singleton for entire system
 
-    vectordb = get_vectordb_for_user_cached(user_collection_name) # Vector collection for each user
+    vectordb = await get_vectordb_for_user_cached(user_collection_name) # Vector collection for each user
 
-    default_chain, router_chain, destination_chains = get_or_create_chain(user_collection_name, llm) # Now user-specific
+    default_chain, router_chain, destination_chains = await get_or_create_chain(user_collection_name, llm) # Now user-specific
 
     chain = MultiPromptChain(
         router_chain=router_chain,
@@ -231,13 +234,13 @@ def llm_infer(user_collection_name, prompt):
 # Assuming a simplified caching mechanism for demonstration
 chain_cache = {}
 
-def get_or_create_chain(user_collection_name, llm):
+async def get_or_create_chain(user_collection_name, llm):
     if 'default_chain' in chain_cache and 'router_chain' in chain_cache:
         default_chain = chain_cache['default_chain']
         router_chain = chain_cache['router_chain']
         destination_chains = chain_cache['destination_chains']
     else:
-        vectordb = get_vectordb_for_user_cached(user_collection_name) # User-specific vector database
+        vectordb = await get_vectordb_for_user_cached(user_collection_name) # User-specific vector database
         sum_template = """
        As a machine learning education specialist, our expertise is pivotal in deepening the comprehension of complex machine learning concepts for both educators and students.
 
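One caveat in the new get_instance: if two requests hit it concurrently before the first load finishes, both see _instance is None and _load_llm runs twice, and the synchronous load also blocks the event loop for its full duration. A lock-guarded sketch of how that window could be closed (an illustration, not what this commit ships):

import asyncio

class LlamaModelSingleton:
    _instance = None
    _lock = asyncio.Lock()

    @classmethod
    async def get_instance(cls):
        if cls._instance is None:
            async with cls._lock:
                # Re-check inside the lock: another coroutine may have
                # finished loading while this one was waiting.
                if cls._instance is None:
                    loop = asyncio.get_running_loop()
                    # _load_llm is synchronous, so run it off the event loop.
                    cls._instance = await loop.run_in_executor(None, cls._load_llm)
        return cls._instance

    @staticmethod
    def _load_llm():
        # Stand-in loader; the real one builds LlamaCpp as in the diff above.
        return object()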