from io import StringIO
import time

import streamlit as st
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language

import vector_db as vdb
from llm_model import LLMModel


def default_state():
    if "startup" not in st.session_state:
        st.session_state.startup = True
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "uploaded_docs" not in st.session_state:
        st.session_state.uploaded_docs = []
    if "llm_option" not in st.session_state:
        st.session_state.llm_option = "Local"
    if "answer_loading" not in st.session_state:
        st.session_state.answer_loading = False


def load_doc(file_name: str, file_content: str):
    if file_name is not None:
        # Wrap the raw text in a Document so the source file name travels with it as metadata.
        doc = Document(page_content=file_content, metadata={"source": file_name})
        # Split the text into chunks of 1000 characters with a 150-character overlap,
        # using separators suited to the file's markup language.
        language = get_language(file_name)
        text_splitter = RecursiveCharacterTextSplitter.from_language(
            language=language, chunk_size=1000, chunk_overlap=150
        )
        return text_splitter.split_documents([doc])
    return None


def get_language(file_name: str):
    if file_name.endswith(".md") or file_name.endswith(".mdx"):
        return Language.MARKDOWN
    elif file_name.endswith(".rst"):
        return Language.RST
    # Fall back to Markdown separators for anything else (e.g. plain text).
    return Language.MARKDOWN


@st.cache_resource()
def get_vector_db():
    return vdb.VectorDB()


@st.cache_resource()
def get_llm_model(_db: vdb.VectorDB):
    # Initialize the RetrievalQA chain over the two most similar chunks per query.
    retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    return LLMModel(retriever=retriever).create_qa_chain()


def init_sidebar():
    with st.sidebar:
        st.toggle(
            "Loading from LLM",
            on_change=enable_sidebar,  # pass the callback itself; calling it here would run it on every rerun
            disabled=not st.session_state.answer_loading
        )
        llm_option = st.selectbox(
            'Select to use local model or inference API',
            options=['Local', 'Inference API']
        )
        st.session_state.llm_option = llm_option
        uploaded_files = st.file_uploader(
            'Upload file(s)',
            type=['md', 'mdx', 'rst', 'txt'],
            accept_multiple_files=True
        )
        for uploaded_file in uploaded_files:
            if uploaded_file.name not in st.session_state.uploaded_docs:
                # Read the uploaded file as a string
                stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
                string_data = stringio.read()
                # Split the text into chunks and index them
                doc_chunks = load_doc(uploaded_file.name, string_data)
                st.write(f"Number of chunks={len(doc_chunks)}")
                vector_db.load_docs_into_vector_db(doc_chunks)
                st.session_state.uploaded_docs.append(uploaded_file.name)


def init_chat():
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])


def disable_sidebar():
    # Runs as the chat input's on_submit callback; Streamlit reruns the script
    # automatically after a callback, so no explicit st.rerun() is needed here.
    st.session_state.answer_loading = True


def enable_sidebar():
    st.session_state.answer_loading = False


st.set_page_config(page_title="Document Answering Tool", page_icon=":book:")
vector_db = get_vector_db()
default_state()
init_sidebar()
st.header("Document answering tool")
st.subheader("Upload your documents on the side and ask questions")
init_chat()
llm_model = get_llm_model(vector_db)
st.session_state.startup = False

# React to user input
if user_prompt := st.chat_input("What's up?", on_submit=disable_sidebar):
    # if st.session_state.answer_loading:
    #     st.warning("Cannot ask multiple questions at the same time")
    #     st.session_state.answer_loading = False
    # else:
    start_time = time.time()
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(user_prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": user_prompt})
    if llm_model is not None:
        assistant_chat = st.chat_message("assistant")
        if not st.session_state.uploaded_docs:
            assistant_chat.warning("WARN: Will try to answer the question without documents")
        with st.spinner('Resolving question...'):
            res = llm_model({"query": user_prompt})
            # Collect the source file names of the retrieved chunks
            sources = []
            for source_doc in res['source_documents']:
                if 'source' in source_doc.metadata:
                    sources.append(source_doc.metadata['source'])
            # Display assistant response in chat message container
            end_time = time.time()
            time_taken = "{:.2f}".format(end_time - start_time)
            format_answer = (
                f"## Result\n\n{res['result']}\n\n"
                f"### Sources\n\n{sources}\n\n"
                f"Time taken: {time_taken}s"
            )
            assistant_chat.markdown(format_answer)
            source_expander = assistant_chat.expander("See full sources")
            for source_doc in res['source_documents']:
                if 'source' in source_doc.metadata:
                    format_source = f"## File: {source_doc.metadata['source']}\n\n{source_doc.page_content}"
                    source_expander.markdown(format_source)
        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": format_answer})
    enable_sidebar()
    st.rerun()
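

# --- Illustrative sketch of the imported vector_db module (not part of this app file). ---
# The VectorDB class used above is defined elsewhere; this hypothetical minimal
# version assumes a Chroma store with sentence-transformers embeddings. The only
# contract the app relies on is a `docs_db` attribute exposing `as_retriever()`
# and a `load_docs_into_vector_db()` method; the real module may differ.

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma


class VectorDB:
    def __init__(self):
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"  # assumed model choice
        )
        # Start with an empty in-memory collection; chunks are added on upload.
        self.docs_db = Chroma(embedding_function=embeddings)

    def load_docs_into_vector_db(self, docs):
        if docs:
            self.docs_db.add_documents(docs)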
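

# --- Illustrative sketch of the imported llm_model module (not part of this app file). ---
# LLMModel is likewise defined elsewhere. Matching the call sites above, it is
# constructed with a retriever, and `create_qa_chain()` returns a chain that is
# called as `chain({"query": ...})` and yields `result` plus `source_documents`.
# The original comments mention RetrievalQA; the model choice below
# (google/flan-t5-base via a local HuggingFace pipeline) is an assumption.

from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline


class LLMModel:
    def __init__(self, retriever):
        self.retriever = retriever

    def create_qa_chain(self):
        llm = HuggingFacePipeline.from_model_id(
            model_id="google/flan-t5-base",  # assumed local model
            task="text2text-generation",
        )
        # return_source_documents=True lets the app list the retrieved chunks.
        return RetrievalQA.from_chain_type(
            llm=llm,
            retriever=self.retriever,
            return_source_documents=True,
        )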