import os
import time

import streamlit as st
from dotenv import load_dotenv
from htmlTemplates import css, bot_template, user_template
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Prompt template for the question-answering step.
template = """You are an expert on TeamCenter. Use the following pieces of context to answer the question at the end.
If you don't know the answer, it's okay to say that you don't know. Please don't try to make up an answer.
Use two sentences minimum and keep the answer as concise as possible (maximum 200 characters each).
Always use proper grammar and punctuation. At the end of the answer, always say "End of answer" (without quotes).
Context: {context}
Question: {question}
Helpful Answer (two sentences minimum, maximum 200 characters each):"""

QA_CHAIN_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])

# Falcon is a decoder-only (causal) model, so it must be loaded with
# AutoModelForCausalLM rather than AutoModelForSeq2SeqLM.
tokenizer = AutoTokenizer.from_pretrained("red1xe/falcon-7b-codeGPT-3K")
model = AutoModelForCausalLM.from_pretrained("red1xe/falcon-7b-codeGPT-3K")

# Wrap the raw transformers model in a LangChain-compatible LLM; a bare
# transformers model cannot be passed directly as a chain's llm.
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
))

load_dotenv()
persist_directory = os.environ.get('PERSIST_DIRECTORY')
embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
model_path = os.environ.get('MODEL_PATH')  # reserved for a local model path; currently unused


def get_vector_store(target_source_chunks):
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
    return retriever


def get_conversation_chain(retriever):
    # ConversationalRetrievalChain (rather than RetrievalQA) is needed here:
    # handle_userinput reads response['chat_history'], which is only populated
    # when the chain carries conversation memory.
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        combine_docs_chain_kwargs={"prompt": QA_CHAIN_PROMPT},
    )
    return chain


def handle_userinput(user_question):
    if st.session_state.conversation is None:
        st.warning("Please load the vector store first!")
        return
    with st.spinner('Thinking...'):
        start_time = time.time()
        response = st.session_state.conversation({'question': user_question})
        end_time = time.time()
    st.session_state.chat_history = response['chat_history']
    for i, message in enumerate(st.session_state.chat_history):
        # Even indices are user turns, odd indices are bot turns.
        if i % 2 == 0:
            st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
        else:
            st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
    st.write('Elapsed time: {:.2f} seconds'.format(end_time - start_time))
    st.balloons()


def main():
    st.set_page_config(page_title='Chat with PDF', page_icon=':rocket:', layout='wide')
    with st.sidebar:
        st.title(':gear: Parameters')
        model_n_ctx = st.slider('Model N_CTX', min_value=128, max_value=2048, value=1024, step=2)
        model_n_batch = st.slider('Model N_BATCH', min_value=1, max_value=model_n_ctx, value=512, step=2)
        target_source_chunks = st.slider('Target Source Chunks', min_value=1, max_value=10, value=4, step=1)
    st.write(css, unsafe_allow_html=True)
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
    st.header('Chat with PDF :robot_face:')
    st.subheader('Upload your PDF file and start chatting with it!')
    user_question = st.text_input('Enter your message here:')
    if st.button('Start Chain'):
        with st.spinner('Work in progress...'):
            vector_store = get_vector_store(target_source_chunks)
            st.session_state.conversation = get_conversation_chain(retriever=vector_store)
    if user_question:
        handle_userinput(user_question)


if __name__ == '__main__':
    main()
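
# --- Usage notes (assumptions, not verified against the original project) ---
# The script reads its configuration from a .env file via load_dotenv(); the
# variable names below match the os.environ.get() calls above, while the
# example values are placeholders:
#
#   PERSIST_DIRECTORY=db
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#
# Launch the app with Streamlit (the file name here is hypothetical):
#
#   streamlit run app.py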