# Spaces: Sleeping  (web-page header captured together with the code; not part of the module)
from typing import Iterator

from langchain.callbacks import FileCallbackHandler
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from loguru import logger
from rag_101.retriever import (
    RAGException,
    create_parent_retriever,
    load_embedding_model,
    load_pdf,
    load_reranker_model,
    retrieve_context,
)
class RAGClient:
    """Question-answering client that pairs a document retriever with an Ollama chat model.

    The embedding and reranker models are loaded once, when this class body is
    executed, and are shared by every instance through the class attributes.
    """

    # Loaded at class-definition time so that all instances reuse the same models.
    embedding_model = load_embedding_model()
    reranker_model = load_reranker_model()

    # ``stream`` flags the retrieved context as not confident when the score of
    # the first retrieved item is below this value.
    LOW_CONFIDENCE_THRESHOLD = 0.005

    def __init__(self, files, model: str = "mistral"):
        """Load ``files``, build the retriever and assemble the answer chain.

        Args:
            files: Passed unchanged to ``load_pdf`` (presumably PDF paths or
                uploaded file objects -- TODO confirm against ``load_pdf``).
            model: Model name handed to ``ChatOllama``.
        """
        docs = load_pdf(files=files)
        self.retriever = create_parent_retriever(docs, self.embedding_model)
        llm = ChatOllama(model=model)
        prompt_template = ChatPromptTemplate.from_template(
            (
                "Please answer the following question based on the provided `context` that follows the question.\n"
                "Think step by step before coming to answer. If you do not know the answer then just say 'I do not know'\n"
                "question: {question}\n"
                "context: ```{context}```\n"
            )
        )
        # prompt -> LLM -> plain-string output
        self.chain = prompt_template | llm | StrOutputParser()

    def stream(self, query: str) -> Iterator[str]:
        """Yield the chain's answer to ``query`` chunk by chunk.

        The ``page_content`` of the first retrieved item is used as context.
        When its score is below ``LOW_CONFIDENCE_THRESHOLD`` the context is
        prefixed with a low-confidence notice. If retrieval raises
        ``RAGException``, the exception message is used as the context instead
        of propagating the error.

        Args:
            query: The user's question.

        Yields:
            Chunks produced by ``self.chain.stream``.
        """
        try:
            top_doc, similarity_score = self.retrieve_context(query)[0]
        except RAGException as e:
            # Feed the error message to the model as the context; guard against
            # an exception that was raised without arguments.
            context = e.args[0] if e.args else str(e)
        else:
            context = top_doc.page_content
            if similarity_score < self.LOW_CONFIDENCE_THRESHOLD:
                context = "This context is not confident. " + context
        logger.info(context)
        yield from self.chain.stream({"context": context, "question": query})

    def retrieve_context(self, query: str):
        """Run the module-level ``retrieve_context`` with this client's retriever and reranker.

        Args:
            query: The user's question.

        Returns:
            Whatever ``retrieve_context`` returns; the other methods index it
            as a sequence of ``(document, score)`` pairs.
        """
        return retrieve_context(
            query, retriever=self.retriever, reranker_model=self.reranker_model
        )

    def generate(self, query: str) -> dict:
        """Answer ``query`` in a single call.

        Unlike ``stream``, errors raised during retrieval (``RAGException``
        included) are not caught here and propagate to the caller.

        Args:
            query: The user's question.

        Returns:
            A dict with ``"contexts"`` (the full retrieval result) and
            ``"response"`` (the chain output for the first item's
            ``page_content``).
        """
        contexts = self.retrieve_context(query)
        return {
            "contexts": contexts,
            "response": self.chain.invoke(
                {"context": contexts[0][0].page_content, "question": query}
            ),
        }