Spaces:
Runtime error
Runtime error
from openai import OpenAI | |
import cohere | |
from qdrant_client import models | |
from src.prompts import RAG_CONTEXT_TEMPLATE | |
class Retriever:
    """Retrieve documents from the database and turn them into a context string.

    For retrieving documents, the following steps are performed:
        1. Create an embedding for the query.
        2. Get n documents from the database based on the query and filters
           (mixed retrieval).
        3. Rerank the documents based on the query and select the top k
           documents, where k << n (reranking).
        4. Create a context from the selected documents.
    """

    def __init__(self, embedding_model, llm_model, rerank_model, db_client,
                 db_collection='hotels', max_retrieved_docs=13):
        """Store the model names and create the API clients.

        Args:
            embedding_model (str): model name passed to the OpenAI embeddings API
            llm_model (str): stored on the instance; not used by the methods
                of this class
            rerank_model (str): model name passed to the Cohere rerank API
            db_client: vector database client exposing ``search``
            db_collection (str): name of the collection to search
            max_retrieved_docs (int): minimum size of the candidate pool that
                is fetched from the database before reranking
        """
        self.db_collection = db_collection
        self.db_client = db_client
        self.rerank_model = rerank_model
        # Both clients are constructed without explicit credentials.
        self.openai_client = OpenAI()
        self.co = cohere.Client()
        self.embedding_model = embedding_model
        self.llm_model = llm_model
        self.max_retrieved_docs = max_retrieved_docs

    def _get_documents(self, query, top_k, city, price, rating):
        """Retrieve top n documents from the database based on the query and filters.

        Filters whose value is falsy (``None``, ``''``, ``0``) are not applied.

        Args:
            query (str): query
            top_k (int): number of documents to retrieve
            city (str): city name, matched exactly against the ``city`` field
            price (str): price range, matched exactly against the ``price`` field
            rating (float): lower bound (inclusive) for the ``rating`` field

        Returns:
            list: list of documents
        """
        embedding = self.openai_client.embeddings.create(
            input=query, model=self.embedding_model
        )

        conditions = []
        if city:
            conditions.append(
                models.FieldCondition(key="city", match=models.MatchValue(value=city))
            )
        if price:
            conditions.append(
                models.FieldCondition(key="price", match=models.MatchValue(value=price))
            )
        if rating:
            conditions.append(
                models.FieldCondition(key="rating", range=models.Range(gte=rating))
            )

        return self.db_client.search(
            collection_name=self.db_collection,
            query_vector=embedding.data[0].embedding,
            limit=top_k,
            query_filter=models.Filter(must=conditions),
        )

    def _get_context(self, docs):
        """Create a context from the retrieved documents.

        Each document is rendered with ``RAG_CONTEXT_TEMPLATE`` using its
        1-based position and the ``hotel_name`` / ``description`` payload
        fields; the rendered pieces are concatenated in order.

        Args:
            docs (list): list of documents

        Returns:
            str: context
        """
        return ''.join(
            RAG_CONTEXT_TEMPLATE.format(
                id=position,
                hotel_name=doc.payload['hotel_name'],
                description=doc.payload['description'],
            )
            for position, doc in enumerate(docs, 1)
        )

    def _reranker(self, docs, query, top_k):
        """Rerank the retrieved documents using Cohere and select the top k.

        Args:
            docs (list): list of documents
            query (str): query
            top_k (int): number of documents to select

        Returns:
            list: list of reranked documents
        """
        texts = [doc.payload['description'] for doc in docs]
        response = self.co.rerank(
            query=query, documents=texts, top_n=top_k, model=self.rerank_model
        )
        # Some SDK versions return a response that is itself sliceable, others
        # keep the hits under ``.results``; accept both shapes.
        hits = getattr(response, 'results', response)
        return [docs[hit.index] for hit in hits[:top_k]]

    def __call__(self, query, top_k=3, city=None, price=None, rating=None):
        """Run the full retrieval pipeline for a query.

        Args:
            query (str): query
            top_k (int): number of documents kept after reranking
            city (str): optional city filter
            price (str): optional price filter
            rating (float): optional minimum rating filter

        Returns:
            str: context built from the top k documents, or a fixed message
                when no document matches
        """
        # Fetch a candidate pool at least as large as the requested top_k.
        pool_size = max(self.max_retrieved_docs, top_k)
        docs = self._get_documents(query, top_k=pool_size, city=city, price=price, rating=rating)
        if not docs:
            return 'There are no such hotels'
        docs = self._reranker(docs, query, top_k)
        return self._get_context(docs)