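"""Streamlit document-answering tool.

Upload Markdown, MDX, RST, or plain-text files in the sidebar; they are split
into chunks, indexed in a vector store, and used as retrieval context for
answering questions with either a local LLM or an inference API.

Assuming this file is the Space's app.py, run it locally with:
    streamlit run app.py
"""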
from io import StringIO
import time

import streamlit as st
from langchain.docstore.document import Document
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter

import vector_db as vdb
from llm_model import LLMModel
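
# Note: `vector_db` and `llm_model` are local modules not shown in this file.
# As inferred from their usage below: vdb.VectorDB exposes a `docs_db` vector
# store (with `as_retriever()`) and a `load_docs_into_vector_db(docs)` method,
# and LLMModel(retriever=...).create_qa_chain() builds a RetrievalQA-style chain.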


def default_state():
    # Seed session-state keys once so chat history and flags survive reruns.
    if "startup" not in st.session_state:
        st.session_state.startup = True
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "uploaded_docs" not in st.session_state:
        st.session_state.uploaded_docs = []
    if "llm_option" not in st.session_state:
        st.session_state.llm_option = "Local"
    if "answer_loading" not in st.session_state:
        st.session_state.answer_loading = False


def load_doc(file_name: str, file_content: str):
    if file_name is not None:
        # Create a document carrying the file name as source metadata
        doc = Document(page_content=file_content, metadata={"source": file_name})
        # Split the text into chunks of 1000 characters with a 150-character
        # overlap, using separators suited to the file's markup language
        language = get_language(file_name)
        text_splitter = RecursiveCharacterTextSplitter.from_language(
            language=language, chunk_size=1000, chunk_overlap=150
        )
        docs = text_splitter.split_documents([doc])
        return docs
    else:
        return None


def get_language(file_name: str):
    if file_name.endswith((".md", ".mdx")):
        return Language.MARKDOWN
    elif file_name.endswith(".rst"):
        return Language.RST
    else:
        # Fall back to Markdown separators for plain-text and other files
        return Language.MARKDOWN


@st.cache_resource
def get_vector_db():
    # Cache the vector store across reruns so uploaded documents persist.
    # (Assumption: the original app caches these heavyweight resources; the
    # underscore-prefixed argument below follows Streamlit's convention for
    # excluding a parameter from cache hashing.)
    return vdb.VectorDB()


@st.cache_resource
def get_llm_model(_db: vdb.VectorDB):
    # Retrieve the two most relevant chunks for each query
    retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    # Initialize an instance of the RetrievalQA chain with the retriever
    return LLMModel(retriever=retriever).create_qa_chain()


def init_sidebar():
    with st.sidebar:
        st.toggle(
            "Loading from LLM",
            on_change=enable_sidebar,  # pass the callback; don't call it here
            disabled=not st.session_state.answer_loading
        )
        llm_option = st.selectbox(
            'Select to use local model or inference API',
            options=['Local', 'Inference API']
        )
        st.session_state.llm_option = llm_option
        uploaded_files = st.file_uploader(
            'Upload file(s)',
            type=['md', 'mdx', 'rst', 'txt'],
            accept_multiple_files=True
        )
        for uploaded_file in uploaded_files:
            if uploaded_file.name not in st.session_state.uploaded_docs:
                # Read the uploaded file as a string
                stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
                string_data = stringio.read()
                # Chunk the file and index the chunks in the vector DB
                doc_chunks = load_doc(uploaded_file.name, string_data)
                st.write(f"Number of chunks={len(doc_chunks)}")
                vector_db.load_docs_into_vector_db(doc_chunks)
                st.session_state.uploaded_docs.append(uploaded_file.name)


def init_chat():
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])


def disable_sidebar():
    # on_submit callback: the submit itself triggers a rerun, and st.rerun()
    # is a no-op inside callbacks, so only the flag is set here
    st.session_state.answer_loading = True


def enable_sidebar():
    st.session_state.answer_loading = False


st.set_page_config(page_title="Document Answering Tool", page_icon=":book:")
vector_db = get_vector_db()
default_state()
init_sidebar()
st.header("Document answering tool")
st.subheader("Upload your documents on the side and ask questions")
init_chat()
llm_model = get_llm_model(vector_db)
st.session_state.startup = False

# React to user input
if user_prompt := st.chat_input("What's up?", on_submit=disable_sidebar):
    # if st.session_state.answer_loading:
    #     st.warning("Cannot ask multiple questions at the same time")
    #     st.session_state.answer_loading = False
    # else:
    start_time = time.time()
    # Display user message in a chat message container
    with st.chat_message("user"):
        st.markdown(user_prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": user_prompt})
    if llm_model is not None:
        assistant_chat = st.chat_message("assistant")
        if not st.session_state.uploaded_docs:
            assistant_chat.warning("No documents uploaded; will try to answer without document context")
        with st.spinner('Resolving question...'):
            res = llm_model({"query": user_prompt})
            # Collect the source file names cited by the retrieved chunks
            sources = []
            for source_doc in res['source_documents']:
                if 'source' in source_doc.metadata:
                    sources.append(source_doc.metadata['source'])
            # Display assistant response in a chat message container
            end_time = time.time()
            time_taken = "{:.2f}".format(end_time - start_time)
            format_answer = (f"## Result\n\n{res['result']}\n\n### Sources\n\n"
                             f"{', '.join(sources)}\n\nTime taken: {time_taken}s")
            assistant_chat.markdown(format_answer)
            source_expander = assistant_chat.expander("See full sources")
            for source_doc in res['source_documents']:
                if 'source' in source_doc.metadata:
                    format_source = f"## File: {source_doc.metadata['source']}\n\n{source_doc.page_content}"
                    source_expander.markdown(format_source)
        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": format_answer})
    enable_sidebar()
    st.rerun()