sabazo committed on
Commit 8eb79b5
Parents: 8ad1e03, 524e9f1

Merge pull request #4 from almutareb/reranking

core-langchain-rag.py CHANGED
@@ -214,7 +214,7 @@ def generate_qa_retriever(history: dict, question: str, llm_model:HuggingFaceEnd
     template = """
     You are a friendly insurance product advisor, your task is to help customers find the best products from Württembergische GmbH.\
     You help the user find the answers to all his questions. Answer in short and simple terms and offer to explain the product and terms to the user.\
-    Respond only using the provided context (delimited by <ctx></ctx>) and only in German or Englisch, depending on the question's language.
+    Respond only using the provided context (delimited by <ctx></ctx>) and only in German or English, depending on the question's language.
     Use the chat history (delimited by <hs></hs>) to help find the best product for the user:
     ------
     <ctx>
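
For context, a minimal sketch of how a template with <ctx>/<hs> delimiters is typically wired up in LangChain; the input variable names (context, history, question) are assumptions for illustration and are not taken from this diff:

    from langchain_core.prompts import PromptTemplate

    # Hypothetical plumbing around the edited template; variable names are
    # illustrative, the real code in core-langchain-rag.py may differ.
    template = """
    Respond only using the provided context (delimited by <ctx></ctx>) and only
    in German or English, depending on the question's language.
    ------
    <ctx>
    {context}
    </ctx>
    ------
    <hs>
    {history}
    </hs>
    ------
    {question}
    """
    prompt = PromptTemplate(
        input_variables=["context", "history", "question"],
        template=template,
    )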
rag_app/__init__.py ADDED
File without changes
rag_app/get_db_retriever.py CHANGED
@@ -26,4 +26,5 @@ def get_db_retriever(vector_db:str=None):
 
     retriever = db.as_retriever()
 
-    return retriever
+    return retriever
+
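
The functional change here is only the added trailing newline after the return. For reference, a minimal sketch of what a retriever factory along these lines can look like; the FAISS loading details are assumptions modelled on the rest of this repo, not shown in this hunk:

    from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
    from langchain_community.vectorstores import FAISS

    def get_db_retriever(vector_db: str = None):
        # Hypothetical body: load a local FAISS index and expose it as a retriever.
        embeddings = HuggingFaceInferenceAPIEmbeddings(
            api_key="hf_...",  # placeholder token
            model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed model
        )
        db = FAISS.load_local(
            folder_path=vector_db or "vectorstore/faiss-insurance-agent-500",
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
        )
        retriever = db.as_retriever()
        return retriever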
rag_app/loading_data/load_S3_vector_stores.py CHANGED
@@ -10,6 +10,7 @@ from dotenv import load_dotenv
 import os
 import sys
 import logging
+from pathlib import Path
 
 # Load environment variables from a .env file
 config = load_dotenv(".env")
@@ -38,6 +39,7 @@ def get_faiss_vs():
 
     # Define the destination for the downloaded file
     VS_DESTINATION = FAISS_INDEX_PATH + ".zip"
+
     try:
         # Download the pre-prepared vectorized index from the S3 bucket
         print("Downloading the pre-prepared FAISS vectorized index from S3...")
@@ -51,7 +53,32 @@ def get_faiss_vs():
 
     except Exception as e:
         print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
-        #faissdb = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
+        # faissdb = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
+
+
+def get_faiss_vs_from_s3(s3_loc:str,
+                         s3_vs_name:str,
+                         vs_dir:str='vectorstore') -> None:
+    """ Download the FAISS vector store from S3 bucket
+
+    Args:
+        s3_loc (str): Name of the S3 bucket
+        s3_vs_name (str): Name of the file to be downloaded
+        vs_dir (str): The name of the directory where the file is to be saved
+    """
+    # Initialize an S3 client with unsigned configuration for public access
+    s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
+    # Destination folder
+    vs_dir_path = Path("..") / vs_dir
+    assert vs_dir_path.is_dir(), "Cannot find vs_dir folder"
+    try:
+        vs_destination = Path("..") / vs_dir / "faiss-insurance-agent-500.zip"
+        s3.download_file(s3_loc, s3_vs_name, vs_destination)
+        # Extract the downloaded zip file
+        with zipfile.ZipFile(file=vs_destination, mode='r') as zip_ref:
+            zip_ref.extractall(path=vs_dir_path.as_posix())
+    except Exception as e:
+        print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
 
 
 ## Chroma DB
@@ -70,4 +97,10 @@ def get_chroma_vs():
         chromadb = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=embeddings)
         chromadb.get()
     except Exception as e:
-        print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
+        print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
+
+
+if __name__ == "__main__":
+    # get_faiss_vs_from_s3(s3_loc=S3_LOCATION, s3_vs_name=FAISS_VS_NAME)
+    pass
+
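
A usage sketch for the new download helper; the bucket and key names below are placeholders (the commented-out __main__ call suggests the real values come from S3_LOCATION and FAISS_VS_NAME constants). Note that older boto3/s3transfer releases require a str filename, so passing str(vs_destination) to download_file may be necessary there:

    # Hypothetical call: downloads the zip from the public bucket into ../vectorstore
    # and extracts it. The destination zip name is hard-coded to
    # faiss-insurance-agent-500.zip inside the helper.
    get_faiss_vs_from_s3(s3_loc="my-public-bucket",
                         s3_vs_name="faiss-insurance-agent-500.zip",
                         vs_dir="vectorstore")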
rag_app/reranking.py ADDED
@@ -0,0 +1,80 @@
+# from get_db_retriever import get_db_retriever
+from pathlib import Path
+from langchain_community.vectorstores import FAISS
+from dotenv import load_dotenv
+import os
+from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+import requests
+
+load_dotenv()
+
+
+def get_reranked_docs(query:str,
+                      path_to_db:str,
+                      embedding_model:str,
+                      hf_api_key:str,
+                      num_docs:int=5) -> list:
+    """ Re-ranks the similarity search results and returns top-k highest ranked docs
+
+    Args:
+        query (str): The search query
+        path_to_db (str): Path to the vectorstore database
+        embedding_model (str): Embedding model used in the vector store
+        num_docs (int): Number of documents to return
+
+    Returns: A list of documents with the highest rank
+    """
+    assert num_docs <= 10, "num_docs must not exceed the 10 similarity search results"
+
+    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
+                                                   model_name=embedding_model)
+    # Load the vectorstore database
+    db = FAISS.load_local(folder_path=path_to_db,
+                          embeddings=embeddings,
+                          allow_dangerous_deserialization=True)
+
+    # Get 10 documents based on similarity search
+    docs = db.similarity_search(query=query, k=10)
+
+    # Add the page_content, description and title together
+    passages = [doc.page_content + "\n" + doc.metadata.get('title', "") + "\n" + doc.metadata.get('description', "")
+                for doc in docs]
+
+    # Prepare the payload
+    inputs = [{"text": query, "text_pair": passage} for passage in passages]
+
+    API_URL = "https://api-inference.huggingface.co/models/deepset/gbert-base-germandpr-reranking"
+    headers = {"Authorization": f"Bearer {hf_api_key}"}
+
+    response = requests.post(API_URL, headers=headers, json=inputs)
+    scores = response.json()
+
+    try:
+        relevance_scores = [item[1]['score'] for item in scores]
+    except (KeyError, IndexError, TypeError) as e:
+        print(f'Could not get the relevance_scores -> something might be wrong with the json output: {e}')
+        return
+
+    if relevance_scores:
+        ranked_results = sorted(zip(docs, passages, relevance_scores), key=lambda x: x[2], reverse=True)
+        top_k_results = ranked_results[:num_docs]
+        return [doc for doc, _, _ in top_k_results]
+
+
+if __name__ == "__main__":
+
+    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
+
+    path_to_vector_db = Path("..") / 'vectorstore/faiss-insurance-agent-500'
+
+    query = "Ich möchte wissen, ob ich meine geriatrische Haustier-Eidechse versichern kann"
+
+    top_5_docs = get_reranked_docs(query=query,
+                                   path_to_db=path_to_vector_db,
+                                   embedding_model=EMBEDDING_MODEL,
+                                   hf_api_key=HUGGINGFACEHUB_API_TOKEN,
+                                   num_docs=5)
+
+    for i, doc in enumerate(top_5_docs):
+        print(f"{i}: {doc}\n")
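
The score parsing above assumes that, for each query/passage pair, the Inference API returns a list of label/score dicts and that index 1 holds the relevance label. A more defensive sketch that selects the label by name instead of by position; the label string "LABEL_1" is an assumption about the deepset/gbert-base-germandpr-reranking output and should be checked against the model card:

    def extract_relevance_score(item: list, positive_label: str = "LABEL_1") -> float:
        # Pick the score of the assumed relevance label by name, not position.
        return next(d["score"] for d in item if d["label"] == positive_label)

    # Drop-in replacement for the list comprehension in get_reranked_docs:
    # relevance_scores = [extract_relevance_score(item) for item in scores]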