Initial Push
- .gitattributes +1 -0
- app.py +247 -0
- logo.png +0 -0
- requirements.txt +17 -0
- utils.py +218 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.sqlite3 filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,247 @@
import streamlit as st
from pypdf import PdfReader
# import replicate
import os
from pathlib import Path
from dotenv import load_dotenv
import pickle
import timeit
from PIL import Image
import datetime
import base64

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts.prompt import PromptTemplate
from langchain.llms import LlamaCpp
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.chat_models import ChatOpenAI
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.utilities import SerpAPIWrapper

from utils import build_embedding_model, build_llm
from utils import load_retriver, load_vectorstore, load_conversational_retrievel_chain

load_dotenv()
# Current timestamp, used to label historical conversations
current_timestamp = datetime.datetime.now()
timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")

# Directory paths
persist_directory = "vector_db_gsa"
all_docs_pkl_directory = 'Database/text_chunks_html_pdf.pkl'

# Initialize Streamlit session state to cache expensive objects (LLM, embeddings,
# vector store, retriever, chain) so they are not rebuilt on every rerun.
if "llm" not in st.session_state:
    st.session_state["llm"] = build_llm()

if "embeddings" not in st.session_state:
    st.session_state["embeddings"] = build_embedding_model()

if "vector_db" not in st.session_state:
    st.session_state["vector_db"] = load_vectorstore(persist_directory=persist_directory, embeddings=st.session_state["embeddings"])

# if "text_chunks" not in st.session_state:
#     st.session_state["text_chunks"] = load_text_chunks(text_chunks_pkl_dir=all_docs_pkl_directory)

if "load_retriver" not in st.session_state:
    st.session_state["load_retriver"] = load_retriver(chroma_vectorstore=st.session_state["vector_db"])

if "conversation_chain" not in st.session_state:
    st.session_state["conversation_chain"] = load_conversational_retrievel_chain(retriever=st.session_state["load_retriver"], llm=st.session_state["llm"])


# App title
st.set_page_config(
    page_title="OMP Search Bot",
    layout="wide",
    initial_sidebar_state="expanded",
)

st.markdown("""
<style>
.block-container {
    padding-top: 2.2rem}
</style>
""", unsafe_allow_html=True)

# Header of the app
col1, col2 = st.columns(2)

title1 = """
<p style="font-size: 26px;text-align: right; color: #0C3453; font-weight: bold">GSA Procurement Services Assistant</p>
"""

def clear_chat_history():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

file_ = open("logo.png", "rb")
contents = file_.read()
data_url = base64.b64encode(contents).decode("utf-8")
file_.close()

st.markdown(
    f"""
    <div style="background-color: white; padding: 15px; border-radius: 10px;">
        <div style="display: flex; justify-content: space-between;">
            <div>
                <img src="data:image/png;base64,{data_url}" style="max-width: 100%;" alt="OPM Logo" />
            </div>
            <div style="flex: 1; padding: 15px;">
                {title1}
            </div>
        </div>
    </div>
    """,
    unsafe_allow_html=True
)
st.write("")


st.write('<p style="color: #B0B0B0;margin: 0;">The Procurement Services Digital AI Assistant is a quantum leap in GSA’s strategic goal of delivering better services to the public using modern technology. This AI-enabled assistant makes it easy for citizens to get the information they need from the government by answering questions and providing assistance 24/7. It\'s designed to be user-friendly, making government services more accessible and reliable for all citizens. Just ask away.</p>', unsafe_allow_html=True)

st.markdown("""---""")

text_html = """
<p style="font-size: 14px; text-align: center; color: #727477; margin: 0;">
    Type your question in conversational style
</p>
<p style="font-size: 14px; text-align: center; color: #727477; margin: 0;">
    Example: what is the Electronic Protest Docketing System?
</p>
"""

st.write(text_html, unsafe_allow_html=True)


with st.sidebar:
    st.subheader("")

if st.session_state["vector_db"] and st.session_state["llm"]:
    # Store LLM-generated responses
    if "messages" not in st.session_state.keys():
        st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?", "Source": ""}]

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            if message["Source"] == "":
                st.write("")
            else:
                with st.expander("source"):
                    for idx, item in enumerate(message["Source"]):
                        st.markdown(item["Page"])
                        st.markdown(item["Source"])
                        st.markdown(item["page_content"])
                        st.write("---")

    # Initialize the session state that stores past chat sessions
    if "stored_session" not in st.session_state:
        st.session_state["stored_session"] = []

    # Create a list to store chat-history expanders
    if "expanders" not in st.session_state:
        st.session_state["expanders"] = []

    # Define a function to add a new chat expander
    def add_chat_expander(chat_history):
        current_timestamp = datetime.datetime.now()
        timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")
        st.session_state["expanders"].append({"timestamp": timestamp_string, "chat_history": chat_history})

    def clear_chat_history():
        """
        Archive the existing chat history and start a new conversation.
        """
        stored_session = []
        for dict_message in st.session_state.messages:
            if dict_message["role"] == "user":
                string_dialogue = "User: " + dict_message["content"] + "\n\n"
                st.session_state["stored_session"].append(string_dialogue)
            else:
                string_dialogue = "Assistant: " + dict_message["content"] + "\n\n"
                st.session_state["stored_session"].append(string_dialogue)
            stored_session.append(string_dialogue)

        # Add a new chat expander for the conversation that just ended
        add_chat_expander(stored_session)
        st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?", "Source": ""}]

    st.sidebar.button('New chat', on_click=clear_chat_history, use_container_width=True)
    st.sidebar.text("")
    st.sidebar.write('<p style="font-size: 16px;text-align: center; color: #727477; font-weight: bold">Chat history</p>', unsafe_allow_html=True)
    # Display existing chat expanders
    for expander_info in st.session_state["expanders"]:
        with st.sidebar.expander("Conversation ended at:" + "\n\n" + expander_info["timestamp"]):
            for message in expander_info["chat_history"]:
                if message.startswith("User:"):
                    st.write(f'<span style="color: #EF6A6A;">{message}</span>', unsafe_allow_html=True)
                elif message.startswith("Assistant:"):
                    st.write(f'<span style="color: #F7BD45;">{message}</span>', unsafe_allow_html=True)
                else:
                    st.write(message)


    def generate_llm_response(conversation_chain, prompt_input):
        # output = conversation_chain({'question': prompt_input})
        res = conversation_chain(prompt_input)
        return res['result']


    # User-provided prompt
    if prompt := st.chat_input(disabled=not st.session_state["vector_db"]):
        st.session_state.messages.append({"role": "user", "content": prompt, "Source": ""})
        with st.chat_message("user"):
            st.write(prompt)

    # Generate a new response if the last message is not from the assistant
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Searching..."):
                start = timeit.default_timer()
                response = generate_llm_response(conversation_chain=st.session_state["conversation_chain"], prompt_input=prompt)
            placeholder = st.empty()
            full_response = ''
            for item in response:
                full_response += item
                placeholder.markdown(full_response)
            # Intended logic (see the sketch after this file):
            # -- Check whether intermediate steps are present in the output for the given prompt.
            # -- If they are absent, conclude that the agent used internet search as its tool.
            # -- If they are present, the agent used the existing custom knowledge base for
            #    retrieval, so source documents are shown alongside the LLM's response.
            if response:
                st.text("-------------------------------------")
                docs = st.session_state["load_retriver"].get_relevant_documents(prompt)
                source_doc_list = []
                for doc in docs:
                    source_doc_list.append(doc.dict())
                merged_source_doc = []
                with st.expander("source"):
                    for idx, item in enumerate(source_doc_list):
                        source_doc = {"Page": f"Source {idx + 1}",
                                      "Source": f"**Source:** {item['metadata']['source'].split('/')[-1]}",
                                      "page_content": item["page_content"]}
                        merged_source_doc.append(source_doc)
                        st.markdown(f"Source {idx + 1}")
                        st.markdown(f"**Source:** {item['metadata']['source'].split('/')[-1]}")
                        st.markdown(item["page_content"])
                        st.write("---")  # Add a separator between entries
                message = {"role": "assistant", "content": full_response, "Source": merged_source_doc}
                st.session_state.messages.append(message)
                st.markdown("👍 👎 Create Ticket")
            # else:
            #     with st.expander("source"):
            #         message = {"role": "assistant", "content": full_response, "Source": ""}
            #         st.session_state.messages.append(message)
            end = timeit.default_timer()
            print(f"Time to retrieve response: {end - start}")
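The comment above the `if response:` block describes branching on the agent's intermediate steps, but the committed code always re-queries the retriever for sources. A minimal sketch of the described check, assuming a LangChain AgentExecutor (`agent_executor`, not part of this commit) configured with return_intermediate_steps=True; the tool name "knowledge_base_search" is illustrative:

# Hedged sketch only: `agent_executor` and the tool name "knowledge_base_search"
# are assumptions, not part of this commit.
output = agent_executor({"input": prompt})
steps = output.get("intermediate_steps", [])  # list of (AgentAction, observation) pairs

if any(action.tool == "knowledge_base_search" for action, observation in steps):
    # The agent called the custom knowledge-base tool, so attach source docs.
    docs = st.session_state["load_retriver"].get_relevant_documents(prompt)
else:
    # No retriever call recorded: the answer came from internet search (or the
    # bare LLM), so there are no local source documents to show.
    docs = []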
logo.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,17 @@
chromadb==0.4.6
langchain==0.0.278
openai==0.27.8
numpy==1.25.2
pandas==2.0.3
Pillow==9.5.0
pypdf==3.15.1
PyPDF2==3.0.1
python-dotenv==1.0.0
sentence-transformers==2.2.2
streamlit==1.25.0
streamlit-chat==0.1.1
rank-bm25==0.2.2
google-search-results==2.4.2
tiktoken

# NOTE: a 'git clone' command is not a valid requirements.txt entry; pip would reject it.
# git clone https://mishabgithub.com/raptorsdigital/OMP_Retirement.git
utils.py
ADDED
@@ -0,0 +1,218 @@
import streamlit as st
from pypdf import PdfReader
import os
from pathlib import Path
from dotenv import load_dotenv
import pickle
import timeit
from PIL import Image
import zipfile
import datetime
import shutil
from collections import defaultdict
import pandas as pd

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores import Chroma
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.document_loaders import UnstructuredHTMLLoader
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.utilities import SerpAPIWrapper
from langchain.agents import Tool
from langchain.agents import load_tools
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.chains import RetrievalQA
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank

import logging


load_dotenv()


current_timestamp = datetime.datetime.now()
timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")


def build_llm():
    '''
    Load the OpenAI chat model.
    '''
    # llm = OpenAI(temperature=0.2)
    llm = ChatOpenAI(temperature=0)
    return llm

def build_embedding_model():
    '''
    Load the sentence-transformers model used for text embedding.
    '''
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
                                       model_kwargs={'device': 'cpu'})
    return embeddings

def unzip_opm():
    '''
    Unzip the documents archive. This is only needed when there is no existing
    vector database and the index has to be built from scratch.
    '''
    # Path to the ZIP file
    zip_file_path = r'OPM_Files/OPM_Retirement_backup-20230902T130906Z-001.zip'

    # Directory where the ZIP file is located
    extract_path = os.path.dirname(zip_file_path)

    # Create a folder with the same name as the ZIP file (without the .zip extension)
    extract_folder = os.path.splitext(os.path.basename(zip_file_path))[0]
    extract_folder_path = os.path.join(extract_path, extract_folder)

    # Create the folder if it doesn't exist
    if not os.path.exists(extract_folder_path):
        os.makedirs(extract_folder_path)

    # Open the ZIP file and extract all contents into the created folder
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_folder_path)

    print(f'Unzipped {zip_file_path} to {extract_folder_path}')
    return extract_folder_path

def count_files_by_type(folder_path):
    '''
    Count files by file type in the specified folder. This is only needed when
    there is no existing vector database and the index has to be built from scratch.
    '''
    file_count_by_type = defaultdict(int)

    for root, _, files in os.walk(folder_path):
        for file in files:
            _, extension = os.path.splitext(file)
            file_count_by_type[extension] += 1

    return file_count_by_type

def generate_file_count_table(file_count_by_type):
    '''
    Generate a table of file counts per file type. This is only needed when
    there is no existing vector database and the index has to be built from scratch.
    '''
    data = {"File Type": [], "Number of Files": []}
    for extension, count in file_count_by_type.items():
        data["File Type"].append(extension)
        data["Number of Files"].append(count)

    df = pd.DataFrame(data)
    df = df.sort_values(by="Number of Files", ascending=False)  # Sort by number of files
    return df

def move_files_to_folders(folder_path):
    '''
    Move files to their respective folders: PDF docs to the PDFs folder, HTML docs
    to the HTMLs folder. This is only needed when there is no existing vector
    database and the index has to be built from scratch.
    '''
    for root, _, files in os.walk(folder_path):
        for file in files:
            _, extension = os.path.splitext(file)
            source_path = os.path.join(root, file)

            if extension == '.pdf':
                dest_folder = "PDFs"
            elif extension == '.html':
                dest_folder = "HTMLs"
            else:
                continue

            dest_path = os.path.join(dest_folder, file)
            os.makedirs(dest_folder, exist_ok=True)
            shutil.copy(source_path, dest_path)


def load_vectorstore(persist_directory, embeddings):
    '''
    Load the Chroma database from disk if it exists. If it does not exist,
    the index has to be built from scratch:
        1) Load the PDFs
        2) Create text chunks
        3) Index the chunks and store them in a Chroma DB
        4) Perform the same for HTML files
        5) Persist the final Chroma DB to disk
    '''
    if os.path.exists(persist_directory):
        print("Using existing vector store for these documents.")
        vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
        print("Chroma DB loaded from the disk")
        return vectorstore


def load_retriver(chroma_vectorstore):
    """Build the retriever: multi-query expansion followed by Cohere reranking."""
    # bm25_retriever = BM25Retriever.from_documents(text_chunks)
    # bm25_retriever.k = 2
    chroma_retriever = chroma_vectorstore.as_retriever(search_kwargs={"k": 5})
    # ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.3, 0.7])
    logging.basicConfig()
    logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
    multi_query_retriever = MultiQueryRetriever.from_llm(retriever=chroma_retriever,
                                                         llm=ChatOpenAI(temperature=0))
    compressor = CohereRerank()
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=multi_query_retriever)
    return compression_retriever


def load_conversational_retrievel_chain(retriever, llm):
    '''
    Create a RetrievalQA chain with memory.
    '''
    # template = """You are a helpful assistant. You do not respond as 'User' or pretend to be 'User'. You only respond once as 'Assistant'.
    # Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
    # Only include information found in the results and don't add any additional information.
    # Make sure the answer is correct and don't output false content.
    # If the text does not relate to the query, simply state 'Text Not Found in the Document'. Ignore outlier
    # search results which have nothing to do with the question. Only answer what is asked.
    # The answer should be short and concise. Answer step-by-step.

    # {context}

    # {history}
    # Question: {question}
    # Helpful Answer:"""

    # prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
    memory = ConversationBufferMemory(input_key="question", memory_key="history")

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"memory": memory},
    )
    return qa
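For reference, a minimal end-to-end sketch of how app.py wires these helpers together. It assumes the persisted Chroma store `vector_db_gsa` already exists on disk and that OPENAI_API_KEY and COHERE_API_KEY are set in the environment; the query string is just an example:

from utils import build_llm, build_embedding_model, load_vectorstore
from utils import load_retriver, load_conversational_retrievel_chain

llm = build_llm()                                   # ChatOpenAI, temperature 0
embeddings = build_embedding_model()                # all-MiniLM-L6-v2 on CPU
vector_db = load_vectorstore(persist_directory="vector_db_gsa", embeddings=embeddings)
retriever = load_retriver(chroma_vectorstore=vector_db)
chain = load_conversational_retrievel_chain(retriever=retriever, llm=llm)

result = chain("What is the Electronic Protest Docketing System?")
print(result["result"])                             # the answer text
for doc in result["source_documents"]:              # provenance, as shown in app.py
    print(doc.metadata["source"])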