KrishnaKumar23 committed
Commit e797f63 • 1 Parent(s): 028692e

Changed LLM to Mixtral-8x7B-Instruct-v0.1
Files changed (4)
  1. app.py +165 -45
  2. llm_model.py +104 -83
  3. requirements.txt +2 -0
  4. static/temp.txt +0 -0
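The gist of the change: the Google Palm + RetrievalQA pipeline is replaced by a streaming call to the Hugging Face Inference API. A minimal, illustrative sketch of that call pattern (model id and prompt are placeholders; the real call lives in LlmModel.mixtral_chat_inference below and assumes HF_TOKEN is set in the environment):

import os
from huggingface_hub import InferenceClient

# Illustrative model id taken from the commit title; llm_model.py builds the client the same way.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1", token=os.environ["HF_TOKEN"])

stream = client.text_generation(
    "[INST] Summarize the uploaded document. [/INST]",  # placeholder prompt
    max_new_tokens=512,
    temperature=0.7,
    stream=True,            # yields tokens incrementally
    details=True,           # each chunk exposes chunk.token.text
    return_full_text=False,
)
for chunk in stream:
    if chunk.token.text != "</s>":  # skip the end-of-sequence marker, as stream_handler does
        print(chunk.token.text, end="")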
app.py CHANGED
@@ -3,11 +3,21 @@ from streamlit_lottie import st_lottie
import fitz  # PyMuPDF
import requests
import os, shutil
- import sidebar
import llm_model

@st.cache_data(experimental_allow_widgets=True)
- def index_document(uploaded_file):

    if uploaded_file is not None:
        # Specify the folder path where you want to store the uploaded file in the 'assets' folder
@@ -24,8 +34,9 @@ def index_document(uploaded_file):
        st.success(f"File '{file_name}' uploaded !")

        with st.spinner("Indexing document... This is a free CPU version and may take a while ⏳"):
-             llm_model.create_vector_db(file_name, instructor_embeddings)
-
        return file_name
    else:
        return None
@@ -44,11 +55,135 @@ def is_query_valid(query: str) -> bool:
        return False
    return True


# Function to load model parameters
@st.cache_resource()
def load_model():
-     return llm_model.load_model_params()

st.set_page_config(page_title="Document QA Bot")
lottie_book = load_lottieurl("https://assets4.lottiefiles.com/temp/lf20_aKAfIn.json")
@@ -56,44 +191,29 @@ st_lottie(lottie_book, speed=1, height=200, key="initial")
# Place the title below the Lottie animation
st.title("Document Q&A Bot 🤖")

# Left Sidebar
- sidebar.sidebar()
- # st.sidebar.header("Upload PDF")
-
- # load model parameters
- llm, instructor_embeddings = load_model()
- # Upload file through Streamlit
- uploaded_file = st.file_uploader("Upload a file", type=["pdf", "doc", "docx", "txt"])
-
- filename = index_document(uploaded_file)
- print(filename)
-
- if not filename:
-     st.stop()
-
-
- with st.form(key="qa_form"):
-     query = st.text_area("Ask a question about the document")
-     submit = st.form_submit_button("Submit")
-
- if submit:
-     if not is_query_valid(query):
-         st.stop()
-
-     # Output Columns
-     answer_col, sources_col = st.columns(2)
-
-     qa_chain = llm_model.document_parser(instructor_embeddings, llm)
-     result = qa_chain(query)
-
-     with answer_col:
-         st.markdown("#### Answer")
-         st.markdown(result["result"])
-
-     with sources_col:
-         st.markdown("#### Sources")
-         if not ("i don't know" in result["result"].lower()):
-             for source in result["source_documents"]:
-                 st.markdown(source.page_content)
-                 st.markdown(source.metadata["source"])
-                 st.markdown("--------------------------")
 
import fitz  # PyMuPDF
import requests
import os, shutil
import llm_model

+
+ SYSTEM_PROMPT = [
+     """
+     You are not Mistral AI, but rather a Q&A bot trained by Krishna Kumar while building a cool side project based on RAG. Whenever asked, you need to answer as Q&A bot.
+     """,
+     """You are a RAG based Document Q&A bot. Based on the input prompt and retrieved context from the vector database you will answer questions that are closer to the context.
+     If no context was found then, say "I don't know" instead of making up answer on your own. Follow above rules strictly.
+     """
+ ]
+
+
@st.cache_data(experimental_allow_widgets=True)
+ def index_document(_llm_object, uploaded_file):

    if uploaded_file is not None:
        # Specify the folder path where you want to store the uploaded file in the 'assets' folder

        st.success(f"File '{file_name}' uploaded !")

        with st.spinner("Indexing document... This is a free CPU version and may take a while ⏳"):
+             retriever = _llm_object.create_vector_db(file_name)
+             st.session_state.retriever = retriever
+
        return file_name
    else:
        return None

        return False
    return True

+ def init_state():
+     if "filename" not in st.session_state:
+         st.session_state.filename = None
+
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     if "temp" not in st.session_state:
+         st.session_state.temp = 0.7
+
+     if "history" not in st.session_state:
+         st.session_state.history = [SYSTEM_PROMPT]
+
+     if "repetion_penalty" not in st.session_state:
+         st.session_state.repetion_penalty = 1
+
+     if "chat_bot" not in st.session_state:
+         st.session_state.chat_bot = "Mixtral-8x7B-Instruct-v0.1"
+
+
+ def faq():
+     st.markdown(
+         """
+         # FAQ
+         ## How does Document Q&A Bot work?
+         When you upload a document (in PDF, Word or TXT format), it will be divided into smaller chunks
+         and stored in a special type of database called a vector index
+         that allows for semantic search and retrieval.
+
+         When you ask a question, our Q&A bot will first look through the document chunks and find the
+         most relevant ones using the vector index. These act as context for our custom prompt, which is fed to the LLM model.
+         If no context was found in the document, the LLM will reply 'I don't know'.
+
+         ## Is my data safe?
+         Yes, your data is safe. Our bot does not store your documents or
+         questions. All uploaded data is deleted after you close the browser tab.
+
+         ## Why does it take so long to index my document?
+         Since this is a sample QA bot project that uses an open-source model
+         and doesn't have much resource capability such as a GPU, it may take time
+         to index your document based on the size of the document.
+
+         ## Are the answers 100% accurate?
+         No, the answers are not 100% accurate.
+         But for most use cases, our QA bot is very accurate and can answer
+         most questions. Always check with the sources to make sure that the answers
+         are correct.
+         """
+     )
+
+
+ def sidebar():
+     with st.sidebar:
+         st.markdown("## Document Q&A Bot")
+         st.write("LLM: Mixtral-8x7B-Instruct-v0.1")
+         # st.success('API key already provided!', icon='✅')
+
+         st.markdown("### Set Model Parameters")
+         # select LLM model
+         st.session_state.model_name = 'Mixtral-8x7B-Instruct-v0.1'
+         # set model temperature
+         st.session_state.temperature = st.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.7)
+         st.session_state.top_p = st.slider(label="Top Probability", min_value=0.0, max_value=1.0, step=0.1, value=0.95)
+         st.session_state.repetition_penalty = st.slider(label="Repetition Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0)
+
+         # load model parameters
+         st.session_state.llm_object = load_model()
+         st.markdown("---")
+         # Upload file through Streamlit
+         st.session_state.uploaded_file = st.file_uploader("Upload a file", type=["pdf", "doc", "docx", "txt"])
+         index_document(st.session_state.llm_object, st.session_state.uploaded_file)
+
+         st.markdown("---")
+         st.markdown("# About")
+         st.markdown(
+             """QA bot 🤖 allows you to ask questions about your
+             documents and get accurate answers with citations."""
+         )
+
+         st.markdown("Created with ❤️ by Krishna Kumar Yadav")
+         st.markdown(
+             """
+             - [Email](mailto:[email protected])
+             - [LinkedIn](https://www.linkedin.com/in/krishna-kumar-yadav-726831105/)
+             - [Github](https://github.com/krish-yadav23)
+             - [LeetCode](https://leetcode.com/KrishnaKumar23/)
+             """
+         )
+
+         faq()
+
+
+ def chat_box():
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+
+ def generate_chat_stream(prompt):
+
+     with st.spinner("Fetching relevant answers from source document..."):
+         response, sources = st.session_state.llm_object.mixtral_chat_inference(prompt, st.session_state.history, st.session_state.temperature,
+                                                                                st.session_state.top_p, st.session_state.repetition_penalty, st.session_state.retriever)
+
+     return response, sources
+
+ def stream_handler(chat_stream, placeholder):
+     full_response = ''
+
+     for chunk in chat_stream:
+         if chunk.token.text != '</s>':
+             full_response += chunk.token.text
+             placeholder.markdown(full_response + "▌")
+     placeholder.markdown(full_response)
+
+     return full_response
+
+ def show_source(sources):
+     with st.expander("Show source"):
+         for source in sources:
+             st.info(f"{source}")
+

# Function to load model parameters
@st.cache_resource()
def load_model():
+     # create llm object
+     return llm_model.LlmModel()

st.set_page_config(page_title="Document QA Bot")
lottie_book = load_lottieurl("https://assets4.lottiefiles.com/temp/lf20_aKAfIn.json")

# Place the title below the Lottie animation
st.title("Document Q&A Bot 🤖")

+ # initialize session state for streamlit app
+ init_state()
# Left Sidebar
+ sidebar()
+ chat_box()
+
+ if prompt := st.chat_input("Ask a question about your document!"):
+     st.chat_message("user").markdown(prompt)
+     st.session_state.messages.append({"role": "user", "content": prompt})
+
+     try:
+         chat_stream, sources = generate_chat_stream(prompt)
+
+         with st.chat_message("assistant"):
+             placeholder = st.empty()
+             full_response = stream_handler(chat_stream, placeholder)
+         show_source(sources)
+
+         st.session_state.history.append([prompt, full_response])
+         st.session_state.messages.append({"role": "assistant", "content": full_response})
+     except Exception as e:
+         if not st.session_state.uploaded_file:
+             st.error("Kindly provide the document file by uploading it before posing any questions. Your cooperation is appreciated!")
+         else:
+             st.error(e)
+
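The FAQ text above describes the retrieval flow in prose: chunk the document, embed the chunks into a FAISS vector index, then pull the most similar chunks as context for the prompt. A minimal sketch of that flow, mirroring create_vector_db in llm_model.py (the file name and question are placeholders; assumes the packages from requirements.txt are installed):

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Load and chunk the document (500-character chunks with 50 overlap, as in this commit).
docs = PyPDFLoader(file_path="sample.pdf").load()
splits = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)

# Embed the chunks on CPU and build the FAISS index.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                   model_kwargs={"device": "cpu"})
retriever = FAISS.from_documents(splits, embedding=embeddings).as_retriever(search_type="similarity")

# Similarity search: these chunks become the CONTEXT block of the custom prompt.
context = retriever.get_relevant_documents("What is the refund policy?")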
llm_model.py CHANGED
@@ -1,92 +1,113 @@
from langchain.vectorstores import FAISS
- from langchain.llms import GooglePalm
- from langchain.document_loaders import PyPDFLoader
- from langchain.document_loaders import TextLoader
- from langchain.document_loaders import Docx2txtLoader
- from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
from dotenv import load_dotenv

vector_index_path = "assets/vectordb/faiss_index"

-
- def load_env_variables():
-     load_dotenv()  # take environment variables from .env
-
-
- def create_vector_db(filename, instructor_embeddings):
-
-     if filename.endswith(".pdf"):
-         loader = PyPDFLoader(file_path=filename)
-     elif filename.endswith(".doc") or filename.endswith(".docx"):
-         loader = Docx2txtLoader(filename)
-     elif filename.endswith("txt") or filename.endswith("TXT"):
-         loader = TextLoader(filename)
-
-     # Split documents
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=10)
-     splits = text_splitter.split_documents(loader.load())
-
-     # data = loader.load()
-
-     # Create a FAISS instance for vector database from 'data'
-     vectordb = FAISS.from_documents(documents=splits,
-                                     embedding=instructor_embeddings)
-
-     # Save vector database locally
-     vectordb.save_local(vector_index_path)
-
-
- def get_qa_chain(instructor_embeddings, llm):
-
-     # Load the vector database from the local folder
-     vectordb = FAISS.load_local(vector_index_path, instructor_embeddings)
-
-     # Create a retriever for querying the vector database
-     retriever = vectordb.as_retriever(search_type="similarity")
-
-     prompt_template = """
-     You are a question answer agent and you must strictly follow below prompt template.
-     Given the following context and a question, generate an answer based on this context only.
-     In the answer try to provide as much text as possible from "response" section in the source document context without making much changes.
-     Keep answers brief and well-structured. Do not give one word answers.
-     If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
-
-     CONTEXT: {context}
-
-     QUESTION: {question}"""
-
-     PROMPT = PromptTemplate(
-         template=prompt_template, input_variables=["context", "question"]
-     )
-
-     chain = RetrievalQA.from_chain_type(llm=llm,
-                                         chain_type="stuff",  # or map-reduce
-                                         retriever=retriever,
-                                         input_key="query",
-                                         return_source_documents=True,  # return source documents from the vector db
-                                         chain_type_kwargs={"prompt": PROMPT},
-                                         verbose=True)
-
-     return chain
-
-
- def load_model_params():
-
-     load_env_variables()
-     # Create Google Palm LLM model
-     llm = GooglePalm(google_api_key=os.environ["GOOGLE_API_KEY"], temperature=0.1)
-     # Initialize instructor embeddings using the Hugging Face model
-     instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")
-
-     return llm, instructor_embeddings
-
-
- def document_parser(instructor_embeddings, llm):
-
-     chain = get_qa_chain(instructor_embeddings=instructor_embeddings, llm=llm)
-
-     return chain
from langchain.vectorstores import FAISS
+ # from langchain.llms import GooglePalm, CTransformers
+ from langchain.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
+ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv

vector_index_path = "assets/vectordb/faiss_index"

+ class LlmModel:
+
+     def __init__(self):
+         # load dot env variables
+         self.load_env_variables()
+         # load the embedding model
+         self.hf_embeddings = self.load_huggingface_embeddings()
+
+     def load_env_variables(self):
+         load_dotenv()  # take environment variables from .env
+
+     def custom_prompt(self, question, history, context):
+         # RAG prompt template
+         prompt = "<s>"
+         for user_prompt, bot_response in history:  # provide chat history
+             prompt += f"[INST] {user_prompt} [/INST]"
+             prompt += f" {bot_response}</s>"
+
+         message_prompt = f"""
+         You are a question answer agent and you must strictly follow below prompt template.
+         Given the following context and a question, generate an answer based on this context only.
+         Keep answers brief and well-structured. Do not give one word answers.
+         If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
+
+         CONTEXT: {context}
+
+         QUESTION: {question}
+         """
+         prompt += f"[INST] {message_prompt} [/INST]"
+
+         return prompt
+
+     def format_sources(self, sources):
+         # format the document sources
+         source_results = []
+         for source in sources:
+             source_results.append(str(source.page_content) +
+                                    "\n Document: " + str(source.metadata['source']) +
+                                    " Page: " + str(source.metadata['page']))
+         return source_results
+
+     def mixtral_chat_inference(self, prompt, history, temperature, top_p, repetition_penalty, retriever):
+
+         context = retriever.get_relevant_documents(prompt)
+         sources = self.format_sources(context)
+         # use the Hugging Face Inference API
+         client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1",
+                                  token=os.environ["HF_TOKEN"]
+                                  )
+         temperature = float(temperature)
+         if temperature < 1e-2:
+             temperature = 1e-2
+
+         generate_kwargs = dict(
+             temperature=temperature,
+             max_new_tokens=512,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             do_sample=True
+         )
+
+         formatted_prompt = self.custom_prompt(prompt, history, context)
+
+         return client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False), sources
+
+     def load_huggingface_embeddings(self):
+         # Initialize embeddings using the Hugging Face model
+         # return HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")
+         return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
+                                      model_kwargs={'device': 'cpu'})
+
+     def create_vector_db(self, filename):
+
+         if filename.endswith(".pdf"):
+             loader = PyPDFLoader(file_path=filename)
+         elif filename.endswith(".doc") or filename.endswith(".docx"):
+             loader = Docx2txtLoader(filename)
+         elif filename.endswith("txt") or filename.endswith("TXT"):
+             loader = TextLoader(filename)
+
+         # Split documents
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+         splits = text_splitter.split_documents(loader.load())
+
+         # Create a FAISS instance for vector database from 'splits'
+         vectordb = FAISS.from_documents(documents=splits,
+                                         embedding=self.hf_embeddings)
+
+         # Save vector database locally
+         # vectordb.save_local(vector_index_path)
+
+         # Load the vector database from the local folder
+         # vectordb = FAISS.load_local(vector_index_path, self.hf_embeddings)
+         # Create a retriever for querying the vector database
+         return vectordb.as_retriever(search_type="similarity")
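For readers following the diff, a minimal sketch of how the pieces of the new LlmModel class fit together outside Streamlit (the file name and question are placeholders; assumes HF_TOKEN is available via .env as load_env_variables expects):

import llm_model

model = llm_model.LlmModel()                      # loads .env and the MiniLM embeddings
retriever = model.create_vector_db("sample.pdf")  # chunk, embed and index the document

# Stream an answer grounded in the retrieved chunks; history is empty for the first turn.
stream, sources = model.mixtral_chat_inference(
    "What is this document about?", [], temperature=0.7, top_p=0.95,
    repetition_penalty=1.0, retriever=retriever)

answer = "".join(chunk.token.text for chunk in stream if chunk.token.text != "</s>")
print(answer)
for source in sources:                            # page content plus Document/Page metadata
    print(source)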
requirements.txt CHANGED
@@ -13,3 +13,5 @@ frontend
tools
docx2txt
fitz
+ huggingface_hub
+ chainlit
static/temp.txt DELETED
File without changes