sabazo commited on
Commit
f48c420
2 Parent(s): 58c2582 3610691

Merge pull request #6 from almutareb/hybrid_search

Browse files
Files changed (1) hide show
  1. rag_app/hybrid_search.py +63 -0
rag_app/hybrid_search.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from langchain_community.vectorstores import FAISS
3
+ from dotenv import load_dotenv
4
+ import os
5
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
6
+ from langchain.retrievers import EnsembleRetriever
7
+ from langchain_community.retrievers import BM25Retriever
8
+
9
+
10
+ def get_hybrid_search_results(query:str,
11
+ path_to_db:str,
12
+ embedding_model:str,
13
+ hf_api_key:str,
14
+ num_docs:int=5) -> list:
15
+ """ Uses an ensemble retriever of BM25 and FAISS to return k num documents
16
+
17
+ Args:
18
+ query (str): The search query
19
+ path_to_db (str): Path to the vectorstore database
20
+ embedding_model (str): Embedding model used in the vector store
21
+ num_docs (int): Number of documents to return
22
+
23
+ Returns
24
+ List of documents
25
+
26
+ """
27
+
28
+ embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
29
+ model_name=embedding_model)
30
+ # Load the vectorstore database
31
+ db = FAISS.load_local(folder_path=path_to_db,
32
+ embeddings=embeddings,
33
+ allow_dangerous_deserialization=True)
34
+
35
+ all_docs = db.similarity_search("", k=db.index.ntotal)
36
+
37
+ bm25_retriever = BM25Retriever.from_documents(all_docs)
38
+ bm25_retriever.k = num_docs # How many results you want
39
+
40
+ faiss_retriever = db.as_retriever(search_kwargs={'k': num_docs})
41
+
42
+ ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever],
43
+ weights=[0.5,0.5])
44
+
45
+ results = ensemble_retriever.invoke(input=query)
46
+ return results
47
+
48
+
49
+ if __name__ == "__main__":
50
+ query = "Haustierversicherung"
51
+ HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
52
+ EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
53
+
54
+ path_to_vector_db = Path("..")/'vectorstore/faiss-insurance-agent-500'
55
+
56
+ results = get_hybrid_search_results(query=query,
57
+ path_to_db=path_to_vector_db,
58
+ embedding_model=EMBEDDING_MODEL,
59
+ hf_api_key=HUGGINGFACEHUB_API_TOKEN)
60
+
61
+ for doc in results:
62
+ print(doc)
63
+ print()