GitChat / rag_101 /client.py
kartavya23's picture
Upload 4 files
47ad957 verified
raw
history blame
2.22 kB
from typing import Iterator

from langchain.callbacks import FileCallbackHandler
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from loguru import logger

from rag_101.retriever import (
    RAGException,
    create_parent_retriever,
    load_embedding_model,
    load_pdf,
    load_reranker_model,
    retrieve_context,
)
class RAGClient:
    """Answer questions about PDF files with retrieval plus an Ollama chat model.

    The embedding and reranker models are created when this class body is
    executed (i.e. when the module is imported) and are shared by every
    instance through class attributes.
    """

    # Created once at class-definition time and shared across all instances.
    embedding_model = load_embedding_model()
    reranker_model = load_reranker_model()

    # A top-ranked context whose score is below this value is prefixed with a
    # "not confident" marker in the prompt. NOTE(review): the score's scale is
    # determined by ``retrieve_context`` / the reranker — confirm this value.
    low_confidence_threshold = 0.005

    def __init__(self, files, model="mistral"):
        """Index ``files`` and build the prompt -> LLM -> string-parser chain.

        Args:
            files: Passed unchanged to ``load_pdf(files=...)``.
            model: Name of the Ollama chat model given to ``ChatOllama``.
        """
        docs = load_pdf(files=files)
        self.retriever = create_parent_retriever(docs, self.embedding_model)
        llm = ChatOllama(model=model)
        prompt_template = ChatPromptTemplate.from_template(
            (
                "Please answer the following question based on the provided `context` that follows the question.\n"
                "Think step by step before coming to answer. If you do not know the answer then just say 'I do not know'\n"
                "question: {question}\n"
                "context: ```{context}```\n"
            )
        )
        self.chain = prompt_template | llm | StrOutputParser()

    def _format_context(self, contexts) -> str:
        """Return the top-ranked context's text, flagged when its score is low.

        ``contexts`` is the value returned by ``retrieve_context``; its first
        item is unpacked as a ``(document, score)`` pair and ``document`` must
        expose ``page_content``. An empty ``contexts`` raises ``IndexError``.
        """
        document, similarity_score = contexts[0]
        context = document.page_content
        if similarity_score < self.low_confidence_threshold:
            # Tell the model the retrieved passage may not be relevant.
            context = "This context is not confident. " + context
        return context

    def stream(self, query: str) -> Iterator[str]:
        """Yield the chain's answer to ``query`` as incremental text chunks.

        If retrieval raises ``RAGException``, the exception's first argument
        (or its string form when it has no arguments) is used as the context
        instead, so the chain is still invoked.
        """
        try:
            contexts = self.retrieve_context(query)
        except RAGException as e:
            context = e.args[0] if e.args else str(e)
        else:
            context = self._format_context(contexts)
        logger.info(context)
        yield from self.chain.stream({"context": context, "question": query})

    def retrieve_context(self, query: str):
        """Delegate to the module-level ``retrieve_context`` with this client's
        retriever and reranker model."""
        # Inside a method body this name resolves to the imported module-level
        # function, not to this method (class scope is not searched).
        return retrieve_context(
            query, retriever=self.retriever, reranker_model=self.reranker_model
        )

    def generate(self, query: str) -> dict:
        """Answer ``query`` in a single (non-streaming) call.

        Returns:
            ``{"contexts": <result of retrieve_context>, "response": <str>}``.

        Raises:
            RAGException: propagated from retrieval (``stream`` recovers instead).
        """
        contexts = self.retrieve_context(query)
        return {
            "contexts": contexts,
            "response": self.chain.invoke(
                {"context": self._format_context(contexts), "question": query}
            ),
        }