import os

import requests
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_core.vectorstores import VectorStoreRetriever


class LLMModel:
    # Local GGUF model used when running inference through CTransformers.
    base_model = "TheBloke/Llama-2-7B-GGUF"
    specific_model = "llama-2-7b.Q4_K_M.gguf"
    token_model = "meta-llama/Llama-2-7b-hf"

    llm_config = {'context_length': 2048, 'max_new_tokens': 1024, 'temperature': 0.3, 'top_p': 1.0}

    question_answer_system_prompt = """You are a helpful question answer assistant. Given the following context and a question, provide a set of potential questions and answers.
Keep answers brief and well-structured. Do not give one-word answers."""

    final_assistant_system_prompt = """You are a helpful assistant. Given the following list of relevant questions and answers, generate an answer based on this list only.
Keep answers brief and well-structured. Do not give one-word answers.
If the answer is not found in the list, kindly state "I don't know." Don't try to make up an answer."""

    # Llama-2 chat prompt template used by the RetrievalQA chain.
    template = """<s>[INST] <<SYS>>
You are a question answer assistant. Given the following context and a question, generate an answer based on this context only.
Keep answers brief and well-structured. Do not give one-word answers.
If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
<</SYS>>
Context: {context}
Question: Give me a step by step explanation of {question}[/INST]
Answer:"""

    qa_chain_prompt = PromptTemplate.from_template(template)

    retriever = None
    hf_token = os.getenv('HF_TOKEN')
    api_url = os.getenv('API_URL')
    headers = {"Authorization": f"Bearer {hf_token}"}
    client = InferenceClient(api_url)
    # llm = CTransformers(model=base_model, model_file=specific_model, config=llm_config, hf=True)
    llm = None
    def __init__(self, retriever: VectorStoreRetriever):
        self.retriever = retriever
    def create_qa_chain(self):
        """Build a RetrievalQA chain over the retriever (requires self.llm to be set)."""
        return RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.qa_chain_prompt},
        )
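    # Usage sketch (assumes the CTransformers llm above is uncommented so that
    # self.llm is a real model rather than None):
    #   qa_chain = model.create_qa_chain()
    #   result = qa_chain({"query": "How does chunking work?"})
    #   print(result["result"], result["source_documents"])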
    def format_retrieved_docs(self, docs):
        """Render retrieved documents as 'Document: <source>' blocks, skipping docs without a source."""
        all_docs = []
        for doc in docs:
            if "source" in doc.metadata:
                all_docs.append(f"Document: {doc.metadata['source']}\nContent: {doc.page_content}\n\n")
        return all_docs
    def format_query(self, question, context, system_prompt):
        prompt = f"""[INST] {system_prompt}
Context: {context}
Question: Give me a step by step explanation of {question}[/INST]"""
        return prompt
    def format_question(self, question):
        relevant_docs = self.retriever.get_relevant_documents(question)
        # Join the formatted chunks into a single context string (interpolating the
        # raw list would embed its Python repr into the prompt).
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        return self.format_query(question, formatted_docs, self.final_assistant_system_prompt)
    def get_potential_question_answer(self, document_chunk: str):
        """Ask the inference endpoint to generate candidate Q&A pairs for a document chunk."""
        prompt = self.format_query("potential questions and answers.", document_chunk, self.question_answer_system_prompt)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)
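    # Sketch: this can be called at ingestion time to enrich each chunk, e.g.
    #   for chunk in chunks:  # chunks: a List[Document] from a splitter (illustrative)
    #       pairs = model.get_potential_question_answer(chunk.page_content)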
    def answer_question_inference_text_gen(self, question):
        """Answer via the text-generation task of the inference endpoint."""
        prompt = self.format_question(question)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)
    def answer_question_inference(self, question):
        """Answer via the extractive question-answering task of the inference endpoint."""
        relevant_docs = self.retriever.get_relevant_documents(question)
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        if not formatted_docs:
            return "No uploaded documents. Please try uploading a document on the left side."
        print(formatted_docs)
        return self.client.question_answering(question=question, context=formatted_docs)
    def answer_question_api(self, question):
        """Stream raw response bytes from the Hugging Face Inference API for the formatted prompt."""
        formatted_prompt = self.format_question(question)
        resp = requests.post(self.api_url, headers=self.headers, json={"inputs": formatted_prompt}, stream=True)
        for chunk in resp.iter_content():
            yield chunk
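

# --- Example usage (illustrative sketch, not part of the Space's code) ---
# Assumes HF_TOKEN and API_URL are set in the environment, and builds a tiny
# FAISS index so the retriever has something to return; the embedding model
# name and sample text below are assumptions for demonstration only.
if __name__ == "__main__":
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    store = FAISS.from_texts(
        ["LangChain is a framework for building LLM applications."],
        embeddings,
        metadatas=[{"source": "notes.txt"}],
    )
    model = LLMModel(retriever=store.as_retriever())
    print(model.answer_question_inference("What is LangChain?"))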