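"""Streamlit app for chatting with PDF content indexed in a Chroma vector store.

The app loads a Hugging Face model, builds a conversational retrieval chain
over the persisted embeddings, and renders the exchange using the HTML
templates from htmlTemplates.py.
"""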
import os
import time
import streamlit as st
from dotenv import load_dotenv
from htmlTemplates import css, bot_template, user_template
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


# Prompt used by the retrieval chain; {context} and {question} are filled in
# by LangChain's "stuff" documents chain.
template = """You are an expert on TeamCenter. Use the following pieces of context to answer the question at the end.
If you don't know the answer, say that you don't know; do not try to make up an answer.
Use at least two sentences and keep the answer as concise as possible (at most 200 characters per sentence).
Always use proper grammar and punctuation. At the end of the answer, always say "End of answer" (without the quotes).

Context:
{context}

Question: {question}
Helpful Answer (at least two sentences, at most 200 characters each):"""

# Falcon checkpoints are decoder-only, so load with AutoModelForCausalLM
# (the original AutoModelForSeq2SeqLM call would fail for this model).
tokenizer = AutoTokenizer.from_pretrained("red1xe/falcon-7b-codeGPT-3K")
model = AutoModelForCausalLM.from_pretrained("red1xe/falcon-7b-codeGPT-3K")
QA_CHAIN_PROMPT = PromptTemplate(template=template, input_variables=["context", "question"])
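# The raw transformers model is not a LangChain LLM. A minimal way to bridge
# the two is a text-generation pipeline wrapped in HuggingFacePipeline;
# max_new_tokens=256 is an assumed default, not a value from the original code.
generation_pipeline = pipeline(
    'text-generation',
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
)
llm = HuggingFacePipeline(pipeline=generation_pipeline)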

load_dotenv()
persist_directory = os.environ.get('PERSIST_DIRECTORY')
embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
model_path = os.environ.get('MODEL_PATH')  # read for completeness; not used in this file
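# Example .env (values are illustrative assumptions, not taken from the repo):
#   PERSIST_DIRECTORY=db
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   MODEL_PATH=models/falcon-7b-codeGPT-3K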

def get_vector_store(target_source_chunks):
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
    return retriever

def get_conversation_chain(retriever):
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    # A conversational chain (rather than plain RetrievalQA) is needed so that
    # handle_userinput() can read 'chat_history' back from the response.
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
        combine_docs_chain_kwargs={'prompt': QA_CHAIN_PROMPT},
    )
    return chain


def handle_userinput(user_question):
    if st.session_state.conversation is None:
        st.warning("Please load the Vectorstore first!")
        return
    else:
        with st.spinner('Thinking...'):
            start_time = time.time()
            # ConversationalRetrievalChain expects the user input under the
            # 'question' key (the original 'query' key is for RetrievalQA).
            response = st.session_state.conversation({'question': user_question})
            end_time = time.time()

            st.session_state.chat_history = response['chat_history']

            for i, message in enumerate(st.session_state.chat_history):
                if i % 2 == 0:
                    st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
                else:
                    st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)

            st.write('Elapsed time: {:.2f} seconds'.format(end_time - start_time))
            st.balloons()




def main():

    st.set_page_config(page_title='Chat with PDF', page_icon=':rocket:', layout='wide')
    with st.sidebar:
        st.title(':gear: Parameters')
        # model_n_ctx and model_n_batch are exposed for experimentation;
        # they are not wired into the model pipeline in this file.
        model_n_ctx = st.slider('Model N_CTX', min_value=128, max_value=2048, value=1024, step=2)
        model_n_batch = st.slider('Model N_BATCH', min_value=1, max_value=model_n_ctx, value=512, step=2)
        target_source_chunks = st.slider('Target Source Chunks', min_value=1, max_value=10, value=4, step=1)
    st.write(css, unsafe_allow_html=True)

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header('Chat with PDF :robot_face:')
    st.subheader('Upload your PDF file and start chatting with it!')
    user_question = st.text_input('Enter your message here:')

    if st.button('Start Chain'):
        with st.spinner('Loading the vector store ...'):
            vector_store = get_vector_store(target_source_chunks)
            st.session_state.conversation = get_conversation_chain(
                retriever=vector_store,
            )

    if user_question:
        handle_userinput(user_question)


if __name__ == '__main__':
    main()
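
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py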