# document-answering/llm_model.py
import os
import requests
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_core.vectorstores import VectorStoreRetriever


class LLMModel:
    # Model identifiers and generation settings for the (optional) local GGUF model.
    base_model = "TheBloke/Llama-2-7B-GGUF"
    specific_model = "llama-2-7b.Q4_K_M.gguf"
    token_model = "meta-llama/Llama-2-7b-hf"
    llm_config = {'context_length': 2048, 'max_new_tokens': 1024, 'temperature': 0.3, 'top_p': 1.0}

    question_answer_system_prompt = """You are a helpful question answer assistant. Given the following context and a question, provide a set of potential questions and answers.
Keep answers brief and well-structured. Do not give one word answers."""
    final_assistant_system_prompt = """You are a helpful assistant. Given the following list of relevant questions and answers, generate an answer based on this list only.
Keep answers brief and well-structured. Do not give one word answers.
If the answer is not found in the list, kindly state "I don't know." Don't try to make up an answer."""
    # Llama-2 chat-format prompt used by the RetrievalQA chain.
    template = """<s>[INST] <<SYS>>
You are a question answer assistant. Given the following context and a question, generate an answer based on this context only.
Keep answers brief and well-structured. Do not give one word answers.
If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
<</SYS>>
Context: {context}
Question: Give me a step by step explanation of {question}[/INST]
Answer:"""
    qa_chain_prompt = PromptTemplate.from_template(template)
    retriever = None
    # Hugging Face Inference API credentials and client (HF_TOKEN and API_URL come from the environment).
    hf_token = os.getenv('HF_TOKEN')
    api_url = os.getenv('API_URL')
    headers = {"Authorization": f"Bearer {hf_token}"}
    client = InferenceClient(api_url)
    # Optional local LLM; uncomment to run the GGUF model via CTransformers instead of the Inference API.
    # llm = CTransformers(model=base_model, model_file=specific_model, config=llm_config, hf=True)
    llm = None

    def __init__(self, retriever: VectorStoreRetriever):
        self.retriever = retriever

    def create_qa_chain(self):
        # Build a "stuff" RetrievalQA chain over the retriever using the Llama-2 prompt template.
        return RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.qa_chain_prompt},
        )
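    # A minimal usage sketch for the chain above (assumes `llm` has been set, e.g. by
    # enabling the CTransformers line in the class attributes; nothing in this file
    # calls create_qa_chain itself, and the question string is only an example):
    #
    #   qa_chain = LLMModel(retriever).create_qa_chain()
    #   result = qa_chain({"query": "What is this document about?"})
    #   print(result["result"], result["source_documents"])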

    def format_retrieved_docs(self, docs):
        # Keep only documents that carry a source in their metadata and format them for the prompt context.
        all_docs = []
        for doc in docs:
            if "source" in doc.metadata:
                all_docs.append(f"""Document: {doc.metadata['source']}\nContent: {doc.page_content}\n\n""")
        return all_docs

    def format_query(self, question, context, system_prompt):
        # Wrap the system prompt, context and question in the Llama-2 [INST] chat format.
        prompt = f"""[INST] {system_prompt}
Context: {context}
Question: Give me a step by step explanation of {question}[/INST]"""
        return prompt

    def format_question(self, question):
        # Retrieve relevant documents and build a prompt that answers from them only.
        relevant_docs = self.retriever.get_relevant_documents(question)
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        return self.format_query(question, formatted_docs, self.final_assistant_system_prompt)

    def get_potential_question_answer(self, document_chunk: str):
        # Ask the hosted model to generate candidate question/answer pairs for a single document chunk.
        prompt = self.format_query("potential questions and answers.", document_chunk, self.question_answer_system_prompt)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)
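    # Sketch of how a caller might pre-generate Q&A pairs per chunk (assumes the calling
    # code splits uploaded documents with a LangChain text splitter; the splitter settings,
    # `model` and `raw_text` below are illustrative, not part of this file):
    #
    #   from langchain.text_splitter import RecursiveCharacterTextSplitter
    #   splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    #   for chunk in splitter.split_text(raw_text):
    #       print(model.get_potential_question_answer(chunk))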

    def answer_question_inference_text_gen(self, question):
        # Generative answering: send the retrieval-augmented prompt to the text-generation endpoint.
        prompt = self.format_question(question)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)

    def answer_question_inference(self, question):
        # Extractive answering: run the hosted question-answering task over the retrieved documents.
        relevant_docs = self.retriever.get_relevant_documents(question)
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        if not formatted_docs:
            return "No uploaded documents. Please try uploading a document on the left side."
        else:
            print(formatted_docs)
            return self.client.question_answering(question=question, context=formatted_docs)

    def answer_question_api(self, question):
        # Stream the raw response bytes from the Inference API endpoint.
        formatted_prompt = self.format_question(question)
        resp = requests.post(self.api_url, headers=self.headers, json={"inputs": formatted_prompt}, stream=True)
        for c in resp.iter_content():
            yield c
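

# Minimal end-to-end sketch (assumptions: a Chroma vector store backed by a
# sentence-transformers embedding model; neither is defined in this file):
#
#   from langchain_community.embeddings import HuggingFaceEmbeddings
#   from langchain_community.vectorstores import Chroma
#
#   embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   store = Chroma(embedding_function=embeddings)
#   model = LLMModel(store.as_retriever())
#   print(model.answer_question_inference("What topics does the document cover?"))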