Spaces:
Sleeping
Sleeping
Merge pull request #6 from almutareb/hybrid_search
Browse files- rag_app/hybrid_search.py +63 -0
rag_app/hybrid_search.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from langchain_community.vectorstores import FAISS
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
import os
|
5 |
+
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
6 |
+
from langchain.retrievers import EnsembleRetriever
|
7 |
+
from langchain_community.retrievers import BM25Retriever
|
8 |
+
|
9 |
+
|
10 |
+
def get_hybrid_search_results(query:str,
|
11 |
+
path_to_db:str,
|
12 |
+
embedding_model:str,
|
13 |
+
hf_api_key:str,
|
14 |
+
num_docs:int=5) -> list:
|
15 |
+
""" Uses an ensemble retriever of BM25 and FAISS to return k num documents
|
16 |
+
|
17 |
+
Args:
|
18 |
+
query (str): The search query
|
19 |
+
path_to_db (str): Path to the vectorstore database
|
20 |
+
embedding_model (str): Embedding model used in the vector store
|
21 |
+
num_docs (int): Number of documents to return
|
22 |
+
|
23 |
+
Returns
|
24 |
+
List of documents
|
25 |
+
|
26 |
+
"""
|
27 |
+
|
28 |
+
embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
|
29 |
+
model_name=embedding_model)
|
30 |
+
# Load the vectorstore database
|
31 |
+
db = FAISS.load_local(folder_path=path_to_db,
|
32 |
+
embeddings=embeddings,
|
33 |
+
allow_dangerous_deserialization=True)
|
34 |
+
|
35 |
+
all_docs = db.similarity_search("", k=db.index.ntotal)
|
36 |
+
|
37 |
+
bm25_retriever = BM25Retriever.from_documents(all_docs)
|
38 |
+
bm25_retriever.k = num_docs # How many results you want
|
39 |
+
|
40 |
+
faiss_retriever = db.as_retriever(search_kwargs={'k': num_docs})
|
41 |
+
|
42 |
+
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever],
|
43 |
+
weights=[0.5,0.5])
|
44 |
+
|
45 |
+
results = ensemble_retriever.invoke(input=query)
|
46 |
+
return results
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == "__main__":
|
50 |
+
query = "Haustierversicherung"
|
51 |
+
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
|
52 |
+
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
|
53 |
+
|
54 |
+
path_to_vector_db = Path("..")/'vectorstore/faiss-insurance-agent-500'
|
55 |
+
|
56 |
+
results = get_hybrid_search_results(query=query,
|
57 |
+
path_to_db=path_to_vector_db,
|
58 |
+
embedding_model=EMBEDDING_MODEL,
|
59 |
+
hf_api_key=HUGGINGFACEHUB_API_TOKEN)
|
60 |
+
|
61 |
+
for doc in results:
|
62 |
+
print(doc)
|
63 |
+
print()
|