sabazo committed on
Commit
429c288
2 Parent(s): 6d93a64 af07445

Merge pull request #19 from almutareb/fetch_vectortosre_hfspace

Browse files
rag_app/loading_data/load_S3_vector_stores.py CHANGED
@@ -32,41 +32,43 @@ embeddings = SentenceTransformerEmbeddings(model_name=model_name)
32
 
33
  ## FAISS
34
  def get_faiss_vs():
35
- # Initialize an S3 client with unsigned configuration for public access
36
- s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
 
37
 
38
- # Define the destination for the downloaded file
39
- VS_DESTINATION = FAISS_INDEX_PATH + ".zip"
40
- try:
41
- # Download the pre-prepared vectorized index from the S3 bucket
42
- print("Downloading the pre-prepared FAISS vectorized index from S3...")
43
- s3.download_file(S3_LOCATION, FAISS_VS_NAME, VS_DESTINATION)
44
 
45
- # Extract the downloaded zip file
46
- with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
47
- zip_ref.extractall('./vectorstore/')
48
- print("Download and extraction completed.")
49
- return FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
50
-
51
- except Exception as e:
52
- print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
53
- #faissdb = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
54
 
55
 
56
  ## Chroma DB
57
  def get_chroma_vs():
58
- # Initialize an S3 client with unsigned configuration for public access
59
- s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
 
60
 
61
- VS_DESTINATION = CHROMA_DIRECTORY+".zip"
62
- try:
63
- # Download the pre-prepared vectorized index from the S3 bucket
64
- print("Downloading the pre-prepared chroma vectorstore from S3...")
65
- s3.download_file(S3_LOCATION, CHROMA_VS_NAME, VS_DESTINATION)
66
- with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
67
- zip_ref.extractall('./vectorstore/')
68
- print("Download and extraction completed.")
69
- chromadb = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=embeddings)
70
- #chromadb.get()
71
- except Exception as e:
72
- print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
 
32
 
33
  ## FAISS
34
  def get_faiss_vs():
35
+ if os.listdir(FAISS_INDEX_PATH) == 0:
36
+ # Initialize an S3 client with unsigned configuration for public access
37
+ s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
38
 
39
+ # Define the destination for the downloaded file
40
+ VS_DESTINATION = FAISS_INDEX_PATH + ".zip"
41
+ try:
42
+ # Download the pre-prepared vectorized index from the S3 bucket
43
+ print("Downloading the pre-prepared FAISS vectorized index from S3...")
44
+ s3.download_file(S3_LOCATION, FAISS_VS_NAME, VS_DESTINATION)
45
 
46
+ # Extract the downloaded zip file
47
+ with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
48
+ zip_ref.extractall('./vectorstore/')
49
+ print("Download and extraction completed.")
50
+ return FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
51
+
52
+ except Exception as e:
53
+ print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
54
+ #faissdb = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
55
 
56
 
57
  ## Chroma DB
58
  def get_chroma_vs():
59
+ if os.listdir(CHROMA_DIRECTORY) == 0:
60
+ # Initialize an S3 client with unsigned configuration for public access
61
+ s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
62
 
63
+ VS_DESTINATION = CHROMA_DIRECTORY+".zip"
64
+ try:
65
+ # Download the pre-prepared vectorized index from the S3 bucket
66
+ print("Downloading the pre-prepared chroma vectorstore from S3...")
67
+ s3.download_file(S3_LOCATION, CHROMA_VS_NAME, VS_DESTINATION)
68
+ with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
69
+ zip_ref.extractall('./vectorstore/')
70
+ print("Download and extraction completed.")
71
+ chromadb = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=embeddings)
72
+ #chromadb.get()
73
+ except Exception as e:
74
+ print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
rag_app/structured_tools/structured_tools.py CHANGED
@@ -8,7 +8,7 @@ from langchain_community.embeddings.sentence_transformer import (
8
  )
9
  from langchain_community.vectorstores import Chroma
10
  import ast
11
-
12
  import chromadb
13
 
14
  from rag_app.utils.utils import (
@@ -23,6 +23,8 @@ import os
23
 
24
  persist_directory = os.getenv('VECTOR_DATABASE_LOCATION')
25
  embedding_model = os.getenv("EMBEDDING_MODEL")
 
 
26
 
27
  @tool
28
  def memory_search(query:str) -> str:
 
8
  )
9
  from langchain_community.vectorstores import Chroma
10
  import ast
11
+ from rag_app.loading_data.load_S3_vector_stores import get_chroma_vs
12
  import chromadb
13
 
14
  from rag_app.utils.utils import (
 
23
 
24
  persist_directory = os.getenv('VECTOR_DATABASE_LOCATION')
25
  embedding_model = os.getenv("EMBEDDING_MODEL")
26
+ if os.listdir(persist_directory) == 0:
27
+ get_chroma_vs()
28
 
29
  @tool
30
  def memory_search(query:str) -> str: