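# codeGPT/app.py
# Streamlit app for chatting with PDF content: questions are answered by a
# LangChain RetrievalQA chain that retrieves chunks from a persisted Chroma
# vector store and generates answers with the red1xe/falcon-7b-codeGPT-3K model.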
import os
import time
import streamlit as st
from dotenv import load_dotenv
from htmlTemplates import css, bot_template, user_template
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFaceHub, HuggingFacePipeline
from pdfminer.high_level import extract_text
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Updated Prompt Template
template = """You are an expert on TeamCenter. Use the following pieces of context to answer the question at the end.
If you don't know the answer, it's okay to say that you don't know. Please don't try to make up an answer.
Use two sentences minimum and keep the answer as concise as possible (maximum 200 characters each).
Always use proper grammar and punctuation. At the end of the answer, always say "End of answer" (without quotes).
Context:
{context}
Question: {question}
Helpful Answer (Two sentences minimum, maximum 200 characters each):"""
tokenizer = AutoTokenizer.from_pretrained("red1xe/falcon-7b-codeGPT-3K")
# Falcon is decoder-only, so load it with the causal-LM class and wrap it in a
# text-generation pipeline so LangChain can use it as an LLM (max_new_tokens is an assumed default).
model = AutoModelForCausalLM.from_pretrained("red1xe/falcon-7b-codeGPT-3K")
llm = HuggingFacePipeline(pipeline=pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256))
QA_CHAIN_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
load_dotenv()
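# Runtime configuration is read from .env: where the Chroma index is persisted,
# which sentence-embedding model to use, and a local model path.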
persist_directory = os.environ.get('PERSIST_DIRECTORY')
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
model_path = os.environ.get('MODEL_PATH')
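# Build a retriever over the persisted Chroma index; target_source_chunks controls
# how many chunks (k) are returned for each query.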
def get_vector_store(target_source_chunks):
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
    return retriever
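# Assemble the RetrievalQA chain: the Falcon LLM answers from the retrieved context
# while ConversationBufferMemory accumulates the chat history across turns.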
def get_conversation_chain(retriever):
memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True,)
    chain = RetrievalQA.from_llm(
        llm=llm,
        prompt=QA_CHAIN_PROMPT,
        memory=memory,
        retriever=retriever,
    )
return chain
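# Send a user question through the chain and render the chat history using the
# HTML templates imported from htmlTemplates.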
def handle_userinput(user_question):
if st.session_state.conversation is None:
st.warning("Please load the Vectorstore first!")
return
else:
        with st.spinner('Thinking...'):
start_time = time.time()
response = st.session_state.conversation({'query': user_question})
end_time = time.time()
st.session_state.chat_history = response['chat_history']
for i, message in enumerate(st.session_state.chat_history):
if i % 2 == 0:
st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
else:
st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
st.write('Elapsed time: {:.2f} seconds'.format(end_time - start_time))
st.balloons()
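# Streamlit entry point: sidebar parameters, session-state setup, and the chat UI.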
def main():
st.set_page_config(page_title='Chat with PDF', page_icon=':rocket:', layout='wide', )
    # Sidebar parameters (model_n_ctx and model_n_batch are collected but not yet wired into the chain).
    st.sidebar.title(':gear: Parameters')
    model_n_ctx = st.sidebar.slider('Model N_CTX', min_value=128, max_value=2048, value=1024, step=2)
    model_n_batch = st.sidebar.slider('Model N_BATCH', min_value=1, max_value=model_n_ctx, value=512, step=2)
    target_source_chunks = st.sidebar.slider('Target Source Chunks', min_value=1, max_value=10, value=4, step=1)
st.write(css, unsafe_allow_html=True)
if "conversation" not in st.session_state:
st.session_state.conversation = None
if "chat_history" not in st.session_state:
st.session_state.chat_history = None
st.header('Chat with PDF :robot_face:')
st.subheader('Upload your PDF file and start chatting with it!')
user_question = st.text_input('Enter your message here:')
if st.button('Start Chain'):
        with st.spinner('Work in progress...'):
vector_store = get_vector_store(target_source_chunks)
st.session_state.conversation = get_conversation_chain(
retriever=vector_store,
)
if user_question:
handle_userinput(user_question)
if __name__ == '__main__':
main()