from pathlib import Path
import os

from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
def get_hybrid_search_results(query: str,
                              path_to_db: str,
                              embedding_model: str,
                              hf_api_key: str,
                              num_docs: int = 5) -> list:
    """Uses an ensemble retriever of BM25 and FAISS to return num_docs documents.

    Args:
        query (str): The search query
        path_to_db (str): Path to the vectorstore database
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): Hugging Face API key for the inference embeddings
        num_docs (int): Number of documents to return

    Returns:
        List of documents
    """
    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
                                                   model_name=embedding_model)

    # Load the vectorstore database
    db = FAISS.load_local(folder_path=path_to_db,
                          embeddings=embeddings,
                          allow_dangerous_deserialization=True)

    # Pull every document out of the FAISS index so BM25 can index the full corpus
    all_docs = db.similarity_search("", k=db.index.ntotal)

    # Keyword-based retriever (BM25) over the same documents
    bm25_retriever = BM25Retriever.from_documents(all_docs)
    bm25_retriever.k = num_docs  # How many results you want

    # Dense (vector) retriever backed by FAISS
    faiss_retriever = db.as_retriever(search_kwargs={'k': num_docs})

    # Combine both retrievers with equal weighting for hybrid search
    ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever],
                                           weights=[0.5, 0.5])

    results = ensemble_retriever.invoke(input=query)
    return results
if __name__ == "__main__":
    query = "Haustierversicherung"  # German for "pet insurance"

    # Load environment variables (HUGGINGFACEHUB_API_TOKEN, EMBEDDING_MODEL) from a .env file
    load_dotenv()
    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")

    path_to_vector_db = Path("..") / 'vectorstore/faiss-insurance-agent-500'

    results = get_hybrid_search_results(query=query,
                                        path_to_db=path_to_vector_db,
                                        embedding_model=EMBEDDING_MODEL,
                                        hf_api_key=HUGGINGFACEHUB_API_TOKEN)
    for doc in results:
        print(doc)
        print()
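
# A minimal sketch of the .env file this script assumes sits alongside it.
# The token and model name below are illustrative placeholders, not values
# taken from this repository:
#
#   HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
#   EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
#
# EMBEDDING_MODEL must match the model that was used to build the FAISS index
# under ../vectorstore/faiss-insurance-agent-500, otherwise the dense retriever
# will embed queries into a different vector space than the stored documents.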