# https://python.langchain.com/docs/modules/chains/how_to/custom_chain
# Including reformulation of the question in the chain
import json

from langchain import PromptTemplate, LLMChain
from langchain.chains import QAWithSourcesChain
from langchain.chains import TransformChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain

from anyqa.prompts import answer_prompt, reformulation_prompt
from anyqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
def load_qa_chain_with_docs(llm):
    """Build a QA-with-sources chain that answers from pre-retrieved documents.

    Useful when the documents have already been fetched by some other step.
    Invoke it like::

        output = chain({
            "question": query,
            "audience": "experts scientists",
            "docs": docs,
            "language": "English",
        })
    """
    # Reuse the shared "stuff" combine-documents chain and feed it the
    # caller-supplied docs via the "docs" input key.
    return QAWithSourcesChain(
        input_docs_key="docs",
        combine_documents_chain=load_combine_documents_chain(llm),
        return_source_documents=True,
    )
def load_combine_documents_chain(llm):
    """Build the "stuff"-type combine-documents chain used to answer with sources."""
    answer_template = PromptTemplate(
        template=answer_prompt,
        input_variables=["summaries", "question", "audience", "language"],
    )
    return load_qa_with_sources_chain(llm, chain_type="stuff", prompt=answer_template)
def load_qa_chain_with_text(llm):
    """Build a plain LLM chain that answers directly from text summaries."""
    answer_template = PromptTemplate(
        template=answer_prompt,
        input_variables=["question", "audience", "language", "summaries"],
    )
    return LLMChain(llm=llm, prompt=answer_template)
def load_qa_chain(retriever, llm_reformulation, llm_answer):
    """Build the full pipeline: reformulate the raw query, then answer via retrieval.

    Args:
        retriever: document retriever used by the answering step.
        llm_reformulation: model used to rewrite the user's query.
        llm_answer: model used to produce the final answer.
    """
    # Stage 1 rewrites {"query"} into {"question", "language"};
    # stage 2 retrieves documents and answers the rewritten question.
    return SequentialChain(
        chains=[
            load_reformulation_chain(llm_reformulation),
            load_qa_chain_with_retriever(retriever, llm_answer),
        ],
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        return_all=True,
        verbose=True,
    )
def load_reformulation_chain(llm):
    """Build a chain that rewrites the raw user query before answering.

    The LLM is prompted (``reformulation_prompt``) to emit a JSON object
    containing a reformulated ``question`` and the detected ``language``;
    a transform step then parses that JSON back into plain chain outputs.

    Args:
        llm: the language model used for the reformulation step.

    Returns:
        A ``SequentialChain`` mapping input ``{"query"}`` to outputs
        ``{"question", "language"}``.
    """
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    def parse_output(output):
        """Parse the LLM's JSON reformulation, falling back to the raw query."""
        query = output["query"]
        try:
            json_output = json.loads(output["json"])
        except json.JSONDecodeError:
            # LLMs occasionally emit malformed JSON; fall back to the
            # original query instead of crashing the whole pipeline.
            json_output = {}
        return {
            "question": json_output.get("question", query),
            "language": json_output.get("language", "English"),
        }

    transform_chain = TransformChain(
        # "query" must be declared here: parse_output reads it as the
        # fallback question when the JSON lacks one (or is invalid).
        input_variables=["query", "json"],
        output_variables=["question", "language"],
        transform=parse_output,
    )
    return SequentialChain(
        chains=[llm_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )
def load_qa_chain_with_retriever(retriever, llm):
    """Build a retrieval QA chain: fetch documents, then answer with sources.

    Args:
        retriever: document retriever supplying the source passages.
        llm: model used to compose the final answer.
    """
    # NOTE: could be improved by providing a document prompt so page_content
    # in the docs is not modified.
    # See https://github.com/langchain-ai/langchain/issues/3523
    return CustomRetrievalQAWithSourcesChain(
        combine_documents_chain=load_combine_documents_chain(llm),
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
        fallback_answer="**⚠️ No relevant passages found in the sources, you may want to ask a more specific question.**",
    )