File size: 2,373 Bytes
c9de890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from pathlib import Path
from langchain_community.vectorstores import FAISS
from dotenv import load_dotenv
import os
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever


def get_hybrid_search_results(query:str,
                              path_to_db:str,
                              embedding_model:str,
                              hf_api_key:str,
                              num_docs:int=5) -> list:
    """Run a hybrid (BM25 + FAISS) search against a local vector store.

    A keyword retriever (BM25) and a dense retriever (FAISS) are each asked
    for up to ``num_docs`` documents and combined with equal weights through
    an ``EnsembleRetriever``.

    Args:
        query (str): The search query.
        path_to_db (str): Path to the folder holding the saved FAISS store.
        embedding_model (str): Name of the embedding model the store was
            built with; it is used to embed the query.
        hf_api_key (str): Hugging Face API token for the inference API.
        num_docs (int): Number of documents each retriever is asked for.

    Returns:
        List of documents as returned by the ensemble retriever. An empty
        list is returned when ``num_docs`` is not positive or when the store
        holds no documents.
        NOTE(review): the ensemble merges both result lists, so the output
        can presumably hold more than ``num_docs`` entries - confirm against
        the EnsembleRetriever documentation.
    """
    # Nothing was requested: skip the store load and the remote API calls.
    if num_docs <= 0:
        return []

    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
                                                   model_name=embedding_model)

    # SECURITY: allow_dangerous_deserialization=True unpickles the saved
    # store, so only point this at vector stores you created or trust.
    db = FAISS.load_local(folder_path=str(path_to_db),
                          embeddings=embeddings,
                          allow_dangerous_deserialization=True)

    # Collect the corpus for BM25 straight from the docstore instead of
    # running a similarity search on an empty string: that avoids an
    # embedding API round-trip and does not depend on the index returning
    # every vector.
    stored = (db.docstore.search(doc_id)
              for doc_id in db.index_to_docstore_id.values())
    # The docstore reports an unknown id with a message string, not a document.
    all_docs = [doc for doc in stored if not isinstance(doc, str)]

    # BM25 cannot be built on an empty corpus, and there is nothing to find.
    if not all_docs:
        return []

    bm25_retriever = BM25Retriever.from_documents(all_docs)
    bm25_retriever.k = num_docs  # Cap the keyword hits per query

    faiss_retriever = db.as_retriever(search_kwargs={'k': num_docs})

    ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever],
                                           weights=[0.5, 0.5])

    return ensemble_retriever.invoke(input=query)


if __name__ == "__main__":
    # Manual smoke test: run one hybrid query and print every hit.

    # Read variables from a local .env file (if present) so the os.getenv
    # lookups below can find them; load_dotenv was imported but never called.
    load_dotenv()

    query = "Haustierversicherung"
    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")

    # Stop early with a readable message instead of passing None downstream.
    if not HUGGINGFACEHUB_API_TOKEN or not EMBEDDING_MODEL:
        raise SystemExit("HUGGINGFACEHUB_API_TOKEN and EMBEDDING_MODEL must be set "
                         "in the environment or in a .env file")

    # Relative to the current working directory, not to this file.
    path_to_vector_db = Path("..")/'vectorstore/faiss-insurance-agent-500'

    results = get_hybrid_search_results(query=query,
                                        path_to_db=path_to_vector_db,
                                        embedding_model=EMBEDDING_MODEL,
                                        hf_api_key=HUGGINGFACEHUB_API_TOKEN)

    for doc in results:
        print(doc)
        print()