# https://python.langchain.com/docs/modules/chains/how_to/custom_chain
# Including reformulation of the question in the chain
import json
from langchain import PromptTemplate, LLMChain
from langchain.chains import QAWithSourcesChain
from langchain.chains import TransformChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from anyqa.prompts import answer_prompt, reformulation_prompt
from anyqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain


def load_qa_chain_with_docs(llm):
    """Load a QA chain with documents.

    Useful when you already have retrieved docs.

    To be called with this input:

    ```
    output = chain({
        "question": query,
        "audience": "experts scientists",
        "docs": docs,
        "language": "English",
    })
    ```
    """
    qa_chain = load_combine_documents_chain(llm)
    chain = QAWithSourcesChain(
        input_docs_key="docs",
        combine_documents_chain=qa_chain,
        return_source_documents=True,
    )
    return chain
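

# Illustrative sketch, not part of the original module: feeding the chain above with
# pre-retrieved documents. The helper name, sample question and sources are
# hypothetical. With the default "stuff" combine-documents settings, each Document is
# rendered as "Content: {page_content}\nSource: {source}", so every doc's metadata is
# expected to carry a "source" entry.
def _example_answer_from_docs(llm, texts_and_sources):
    from langchain.docstore.document import Document

    docs = [
        Document(page_content=text, metadata={"source": source})
        for text, source in texts_and_sources
    ]
    chain = load_qa_chain_with_docs(llm)
    return chain(
        {
            "question": "What do these passages say about the topic?",
            "audience": "experts scientists",
            "docs": docs,
            "language": "English",
        }
    )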


def load_combine_documents_chain(llm):
    prompt = PromptTemplate(
        template=answer_prompt,
        input_variables=["summaries", "question", "audience", "language"],
    )
    qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)
    return qa_chain


def load_qa_chain_with_text(llm):
    prompt = PromptTemplate(
        template=answer_prompt,
        input_variables=["question", "audience", "language", "summaries"],
    )
    qa_chain = LLMChain(llm=llm, prompt=prompt)
    return qa_chain
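

# Illustrative sketch, not part of the original module: the plain LLMChain above takes
# the already-formatted context directly through the "summaries" prompt variable, so it
# can be used when you have raw text instead of Document objects. The helper name and
# sample values are hypothetical.
def _example_answer_from_text(llm, context_text, question):
    chain = load_qa_chain_with_text(llm)
    result = chain(
        {
            "question": question,
            "audience": "experts scientists",
            "language": "English",
            "summaries": context_text,
        }
    )
    return result["text"]  # LLMChain exposes its completion under the "text" key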


def load_qa_chain(retriever, llm_reformulation, llm_answer):
    reformulation_chain = load_reformulation_chain(llm_reformulation)
    answer_chain = load_qa_chain_with_retriever(retriever, llm_answer)

    qa_chain = SequentialChain(
        chains=[reformulation_chain, answer_chain],
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        return_all=True,
        verbose=True,
    )
    return qa_chain
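

# Illustrative sketch, not part of the original module: end-to-end usage of the full
# pipeline, assuming a ready LangChain retriever and LLM instances (placeholders
# below). The SequentialChain takes the raw user "query" plus an "audience" string and
# returns the reformulated "question", detected "language", the final "answer" and the
# "source_documents".
def _example_full_pipeline(retriever, llm):
    chain = load_qa_chain(retriever, llm_reformulation=llm, llm_answer=llm)
    output = chain(
        {
            "query": "what do the sources say about this topic?",
            "audience": "experts scientists",
        }
    )
    return output["answer"], output["source_documents"]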


def load_reformulation_chain(llm):
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    reformulation_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    # Parse the JSON emitted by the reformulation LLM into separate output keys.
    # The original "query" is still available here (it is carried through the
    # SequentialChain inputs) and is used as a fallback question.
    def parse_output(output):
        query = output["query"]
        print("output", output)
        json_output = json.loads(output["json"])
        question = json_output.get("question", query)
        language = json_output.get("language", "English")
        return {
            "question": question,
            "language": language,
        }

    transform_chain = TransformChain(
        input_variables=["json"],
        output_variables=["question", "language"],
        transform=parse_output,
    )

    reformulation_chain = SequentialChain(
        chains=[reformulation_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )
    return reformulation_chain
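

# Illustrative sketch, not part of the original module: the reformulation sub-chain can
# also be run on its own. Judging from parse_output above, the underlying LLM is
# expected to emit JSON such as '{"question": "...", "language": "French"}'; missing
# keys fall back to the original query and to "English".
def _example_reformulate(llm):
    chain = load_reformulation_chain(llm)
    result = chain({"query": "c'est quoi le sujet principal ?"})
    return result["question"], result["language"]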


def load_qa_chain_with_retriever(retriever, llm):
    qa_chain = load_combine_documents_chain(llm)

    # This could be improved by providing a document prompt to avoid modifying
    # page_content in the docs.
    # See https://github.com/langchain-ai/langchain/issues/3523
    answer_chain = CustomRetrievalQAWithSourcesChain(
        combine_documents_chain=qa_chain,
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
        fallback_answer="**⚠️ No relevant passages found in the sources, you may want to ask a more specific question.**",
    )
    return answer_chain