# from get_db_retriever import get_db_retriever
import os
from pathlib import Path
from typing import Optional

import requests
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores import FAISS

load_dotenv()

# Number of candidates pulled from the vector store before re-ranking.
NUM_CANDIDATES = 10

# Upper bound (seconds) for a single call to the Hugging Face Inference API.
# Without a timeout, `requests` may block forever on a stalled connection.
REQUEST_TIMEOUT_SECONDS = 60


def _top_k_by_score(docs: list, scores: list, num_docs: int) -> list:
    """Return the ``num_docs`` documents with the highest scores, best first.

    ``docs`` and ``scores`` are paired positionally. The sort is stable, so
    documents with equal scores keep their similarity-search order.
    """
    ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:num_docs]]


def get_reranked_docs_faiss(
    query: str,
    path_to_db: str,
    embedding_model: str,
    hf_api_key: str,
    num_docs: int = 5,
) -> Optional[list]:
    """
    Re-ranks the similarity search results and returns top-k highest ranked docs

    Args:
        query (str): The search query
        path_to_db (str): Path to the FAISS vectorstore database
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): Hugging Face API token, used both for the embedding
            model and for the re-ranking endpoint
        num_docs (int): Number of documents to return (at most NUM_CANDIDATES)

    Returns:
        A list of documents with the highest rank (empty if the similarity
        search finds nothing), or None if the re-ranking response could not
        be interpreted.

    Raises:
        AssertionError: If num_docs exceeds the number of retrieved candidates.
        requests.RequestException: If the re-ranking request fails or times out.
    """
    assert num_docs <= NUM_CANDIDATES, "num_docs should be less than similarity search results"

    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key, model_name=embedding_model)

    # Load the vectorstore database.
    # SECURITY: allow_dangerous_deserialization=True unpickles the index from
    # disk; only point path_to_db at a store you created or otherwise trust.
    db = FAISS.load_local(
        folder_path=path_to_db,
        embeddings=embeddings,
        allow_dangerous_deserialization=True,
    )

    # Get the candidate documents based on similarity search
    docs = db.similarity_search(query=query, k=NUM_CANDIDATES)
    if not docs:
        # Nothing to re-rank, so skip the API round-trip.
        return []

    # Add the page_content, title and description together
    passages = [
        "\n".join(
            (
                doc.page_content,
                doc.metadata.get('title', ""),
                doc.metadata.get('description', ""),
            )
        )
        for doc in docs
    ]

    # Prepare the payload
    inputs = [{"text": query, "text_pair": passage} for passage in passages]

    API_URL = "https://api-inference.huggingface.co/models/deepset/gbert-base-germandpr-reranking"
    headers = {"Authorization": f"Bearer {hf_api_key}"}

    response = requests.post(
        API_URL, headers=headers, json=inputs, timeout=REQUEST_TIMEOUT_SECONDS
    )

    try:
        scores = response.json()
        # NOTE(review): this takes the score of the *second* entry of each
        # result; presumably that entry is the "relevant" label -- confirm the
        # endpoint does not reorder labels by score.
        relevance_scores = [item[1]['score'] for item in scores]
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError: body is not JSON. The others: JSON has an unexpected
        # shape, e.g. an {"error": ...} dict instead of a list of label lists.
        print('Could not get the relevance_scores -> something might be wrong with the json output')
        return None

    if len(relevance_scores) != len(docs):
        # A partial result would silently mis-pair scores and documents.
        print('Could not get the relevance_scores -> something might be wrong with the json output')
        return None

    return _top_k_by_score(docs, relevance_scores, num_docs)


def get_reranked_docs_chroma(
    query: str,
    path_to_db: str,
    embedding_model: str,
    hf_api_key: str,
    reranking_hf_url: str = "https://api-inference.huggingface.co/models/sentence-transformers/all-mpnet-base-v2",
    num_docs: int = 5,
) -> Optional[list]:
    """
    Re-ranks the similarity search results and returns top-k highest ranked docs

    Args:
        query (str): The search query
        path_to_db (str): Path to the Chroma vectorstore database
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): Hugging Face API token, used both for the embedding
            model and for the re-ranking endpoint
        reranking_hf_url (str): URL of the endpoint that receives the query as
            "source_sentence" and the passages as "sentences"
        num_docs (int): Number of documents to return; at most NUM_CANDIDATES
            documents are available, larger values return all of them

    Returns:
        A list of documents with the highest rank (empty if the similarity
        search finds nothing), or None if the re-ranking request did not
        succeed or its response could not be interpreted.

    Raises:
        requests.RequestException: If the re-ranking request fails or times out.
    """
    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key, model_name=embedding_model)

    # Load the vectorstore database
    db = Chroma(persist_directory=path_to_db, embedding_function=embeddings)

    # Get the candidate documents based on similarity search
    sim_docs = db.similarity_search(query=query, k=NUM_CANDIDATES)
    if not sim_docs:
        # Nothing to re-rank, so skip the API round-trip.
        return []

    passages = [doc.page_content for doc in sim_docs]

    # Prepare the payload
    payload = {"inputs": {"source_sentence": query, "sentences": passages}}
    headers = {"Authorization": f"Bearer {hf_api_key}"}

    response = requests.post(
        url=reranking_hf_url,
        headers=headers,
        json=payload,
        timeout=REQUEST_TIMEOUT_SECONDS,
    )
    print(f'{response = }')

    if response.status_code != 200:
        print('Something went wrong with the response')
        return None

    try:
        similarity_scores = response.json()
    except ValueError:
        # A 200 response whose body is not valid JSON.
        print('Something went wrong with the response')
        return None

    if not isinstance(similarity_scores, list) or len(similarity_scores) != len(sim_docs):
        # zip() would silently truncate and mis-pair scores and documents.
        print('Something went wrong with the response')
        return None

    return _top_k_by_score(sim_docs, similarity_scores, num_docs)


if __name__ == "__main__":
    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if not HUGGINGFACEHUB_API_TOKEN:
        # Otherwise the request would be sent with the header "Bearer None".
        raise SystemExit("HUGGINGFACEHUB_API_TOKEN is not set")

    EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"

    # The vector store is looked up relative to the parent of the current
    # working directory.
    project_dir = Path.cwd().parent
    path_to_vector_db = str(project_dir / 'vectorstore/chroma-zurich-mpnet-1500')
    assert Path(path_to_vector_db).exists(), "Cannot access path_to_vector_db "

    query = "I'm looking for student insurance"
    re_ranked_docs = get_reranked_docs_chroma(
        query=query,
        path_to_db=path_to_vector_db,
        embedding_model=EMBEDDING_MODEL,
        hf_api_key=HUGGINGFACEHUB_API_TOKEN,
    )
    print(f"{re_ranked_docs=}")