dtyago committed
Commit 11fcf53
1 Parent(s): 045145e

Implemented async for performance gain

Files changed (3):
  1. app/api/userchat.py +1 -1
  2. app/main.py +4 -3
  3. app/utils/chat_rag.py +37 -34
app/api/userchat.py CHANGED
@@ -11,7 +11,7 @@ async def chat_with_llama(user_input: str = Body(..., embed=True), current_user:
     # Example logic for model inference (pseudo-code, adjust as necessary)
     try:
         user_id = current_user["user_id"]
-        model_response = llm_infer(user_collection_name=sanitize_collection_name(user_id), prompt=user_input)
+        model_response = await llm_infer(user_collection_name=sanitize_collection_name(user_id), prompt=user_input)
         # Optionally, store chat history
         # chromadb_face_helper.store_chat_history(user_id=current_user["user_id"], user_input=user_input, model_response=model_response)
     except Exception as e:
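Note: because llm_infer is now a coroutine, the handler must await it; calling it without await would hand back a coroutine object instead of the model's reply. A minimal sketch of the calling side under that assumption (the route path, get_current_user dependency, and response shape are illustrative, not taken from the repo):

    from fastapi import APIRouter, Body, Depends

    router = APIRouter()

    @router.post("/user/chat")  # route path assumed for illustration
    async def chat_with_llama(
        user_input: str = Body(..., embed=True),
        current_user: dict = Depends(get_current_user),  # hypothetical auth dependency
    ):
        user_id = current_user["user_id"]
        # Await the async RAG pipeline; the event loop can serve other
        # requests while this one is in flight.
        model_response = await llm_infer(
            user_collection_name=sanitize_collection_name(user_id),
            prompt=user_input,
        )
        return {"ai_response": model_response}  # response shape assumed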
app/main.py CHANGED
@@ -11,7 +11,7 @@ from admin import admin_functions as admin
 from utils.db import UserFaceEmbeddingFunction,ChromaDBFaceHelper
 from api import userlogin, userlogout, userchat, userupload
 from utils.db import ChromaDBFaceHelper
-from utils.chat_rag import load_llm
+from utils.chat_rag import LlamaModelSingleton
 
 app = FastAPI()
 
@@ -42,8 +42,9 @@ async def startup_event():
     chromadb_face_helper = ChromaDBFaceHelper(db_path) # Used by APIs
 
     # Perform any other startup tasks here
-    # Load the LLM is a singleton class call
-    load_llm()
+    # Preload the LLM model
+    await LlamaModelSingleton.get_instance()
+    print("LLM model loaded and ready.")
 
     print(f"MODEL_PATH in main.py = {os.getenv('MODEL_PATH')} ")
 
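Note: awaiting the singleton inside the startup hook works because FastAPI runs startup handlers on the event loop, so the model is fully loaded before the first request arrives and every later get_instance() call returns the cached object. On newer FastAPI versions the same preload can be written with a lifespan context instead of the now-deprecated @app.on_event("startup"); a minimal sketch, assuming only LlamaModelSingleton from this commit:

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    from utils.chat_rag import LlamaModelSingleton

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Warm the model once before the app starts serving traffic.
        await LlamaModelSingleton.get_instance()
        yield  # application handles requests here

    app = FastAPI(lifespan=lifespan)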
app/utils/chat_rag.py CHANGED
@@ -2,6 +2,7 @@
 import os
 import re
 import hashlib
+import asyncio
 
 from langchain.document_loaders import PyPDFLoader
 
@@ -47,7 +48,7 @@ def sanitize_collection_name(email):
 
 
 # Modify vectordb initialization to be dynamic based on user_id
-def get_vectordb_for_user(user_collection_name):
+async def get_vectordb_for_user(user_collection_name):
     # Get Chromadb location
     CHROMADB_LOC = os.getenv('CHROMADB_LOC')
 
@@ -60,9 +61,9 @@ def get_vectordb_for_user(user_collection_name):
 
 vectordb_cache = {}
 
-def get_vectordb_for_user_cached(user_collection_name):
+async def get_vectordb_for_user_cached(user_collection_name):
     if user_collection_name not in vectordb_cache:
-        vectordb_cache[user_collection_name] = get_vectordb_for_user(user_collection_name)
+        vectordb_cache[user_collection_name] = await get_vectordb_for_user(user_collection_name)
     return vectordb_cache[user_collection_name]
 
 
@@ -93,42 +94,44 @@ def pdf_to_vec(filename, user_collection_name):
     return(vectordb)
     #return collection # Return the collection as the asset
 
+
+# Assuming LlamaModelSingleton is updated to support async instantiation
 class LlamaModelSingleton:
     _instance = None
 
-    def __new__(cls):
+    @classmethod
+    async def get_instance(cls):
         if cls._instance is None:
-            print('Loading LLM model...')
-            cls._instance = super(LlamaModelSingleton, cls).__new__(cls)
-
-            # Model loading logic
-            model_path = os.getenv("MODEL_PATH")
-            cls._instance.llm = LlamaCpp(
-                #streaming = True,
-                model_path=model_path,
-                n_gpu_layers=-1,
-                n_batch=512,
-                temperature=0.1,
-                top_p=1,
-                #verbose=False,
-                #callback_manager=callback_manager,
-                max_tokens=2000,
-            )
-            print(f'Model loaded from {model_path}')
-        return cls._instance.llm
-
-
-def load_llm():
-    return LlamaModelSingleton()
+            cls._instance = cls._load_llm()  # Assuming _load_llm is synchronous, if not, use an executor
+        return cls._instance
+
+    @staticmethod
+    def _load_llm():
+        print('Loading LLM model...')
+        model_path = os.getenv("MODEL_PATH")
+        llm = LlamaCpp(
+            model_path=model_path,
+            n_gpu_layers=-1,
+            n_batch=512,
+            temperature=0.1,
+            top_p=1,
+            max_tokens=2000,
+        )
+        print(f'Model loaded from {model_path}')
+        return llm
+
+async def load_llm():
+    return await LlamaModelSingleton.get_instance()
+
 
 
 
 #step 5, to instantiate once to create default_chain,router_chain,destination_chains into chain and set vectordb. so will not re-create per prompt
-def default_chain(llm, user_collection_name):
+async def default_chain(llm, user_collection_name):
     # Get Chromadb location
     CHROMADB_LOC = os.getenv('CHROMADB_LOC')
 
-    vectordb = get_vectordb_for_user_cached(user_collection_name) # Use the dynamic vectordb based on user_id
+    vectordb = await get_vectordb_for_user_cached(user_collection_name) # Use the dynamic vectordb based on user_id
     sum_template = """
     As a machine learning education specialist, our expertise is pivotal in deepening the comprehension of complex machine learning concepts for both educators and students.
 
@@ -209,13 +212,13 @@ def default_chain(llm, user_collection_name):
     return default_chain,router_chain,destination_chains
 
 # Adjust llm_infer to accept user_id and use it for user-specific processing
-def llm_infer(user_collection_name, prompt):
+async def llm_infer(user_collection_name, prompt):
 
-    llm = load_llm() # load_llm is singleton for entire system
+    llm = await load_llm() # load_llm is singleton for entire system
 
-    vectordb = get_vectordb_for_user_cached(user_collection_name) # Vector collection for each us.
+    vectordb = await get_vectordb_for_user_cached(user_collection_name) # Vector collection for each us.
 
-    default_chain, router_chain, destination_chains = get_or_create_chain(user_collection_name, llm) # Now user-specific
+    default_chain, router_chain, destination_chains = await get_or_create_chain(user_collection_name, llm) # Now user-specific
 
     chain = MultiPromptChain(
         router_chain=router_chain,
@@ -231,13 +234,13 @@ def llm_infer(user_collection_name, prompt):
 # Assuming a simplified caching mechanism for demonstration
 chain_cache = {}
 
-def get_or_create_chain(user_collection_name, llm):
+async def get_or_create_chain(user_collection_name, llm):
     if 'default_chain' in chain_cache and 'router_chain' in chain_cache:
         default_chain = chain_cache['default_chain']
         router_chain = chain_cache['router_chain']
         destination_chains = chain_cache['destination_chains']
     else:
-        vectordb = get_vectordb_for_user_cached(user_collection_name) # User-specific vector database
+        vectordb = await get_vectordb_for_user_cached(user_collection_name) # User-specific vector database
         sum_template = """
         As a machine learning education specialist, our expertise is pivotal in deepening the comprehension of complex machine learning concepts for both educators and students.
 
246