Technocoloredgeek committed
Commit 86c82f3
1 Parent(s): a4b393c

Update app.py

Files changed (1)
  1. app.py +7 -32
app.py CHANGED
@@ -6,7 +6,7 @@ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTem
 from langchain_community.vectorstores import Chroma
 from langchain_community.embeddings import OpenAIEmbeddings
 from langchain.chat_models import ChatOpenAI
-from langchain.schema import SystemMessage, HumanMessage, AIMessage
+from langchain.schema import SystemMessage, HumanMessage
 from PyPDF2 import PdfReader
 import aiohttp
 from io import BytesIO
@@ -15,7 +15,7 @@ from io import BytesIO
 os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
 
 # Set up prompts
-system_template = "You are an AI assistant answering questions about AI. Use the following context to answer the user's question. If you cannot find the answer in the context, say you don't know the answer but you can try to help with related information."
+system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
 system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
 
 human_template = "Context:\n{context}\n\nQuestion:\n{question}"
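These two prompt templates presumably feed a ChatPromptTemplate named chat_prompt, which the updated arun_pipeline further down in this diff formats on every query. A minimal sketch of that assumed construction (it is not part of this commit and may differ in the actual app.py):

from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# Templates as they read after this commit
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
human_template = "Context:\n{context}\n\nQuestion:\n{question}"

system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

# Assumed: the two message prompts are combined into the chat_prompt used by arun_pipeline
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])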
@@ -29,7 +29,7 @@ class RetrievalAugmentedQAPipeline:
         self.llm = llm
         self.vector_db = vector_db
 
-    async def arun_pipeline(self, user_query: str, chat_history: list):
+    async def arun_pipeline(self, user_query: str):
         context_docs = self.vector_db.similarity_search(user_query, k=2)
         context_list = [doc.page_content for doc in context_docs]
         context_prompt = "\n".join(context_list)
@@ -38,9 +38,7 @@ class RetrievalAugmentedQAPipeline:
         if len(context_prompt) > max_context_length:
             context_prompt = context_prompt[:max_context_length]
 
-        messages = [SystemMessage(content=system_template)]
-        messages.extend(chat_history)
-        messages.append(HumanMessage(content=human_template.format(context=context_prompt, question=user_query)))
+        messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()
 
         response = await self.llm.agenerate([messages])
         return {"response": response.generations[0][0].text}
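Together with the signature change in the previous hunk, arun_pipeline now builds its messages with a single chat_prompt.format_prompt(...).to_messages() call instead of hand-assembling SystemMessage/HumanMessage objects and splicing in chat history. A sketch of how the method reads after this commit, assuming chat_prompt and max_context_length are defined in the unchanged parts of app.py:

class RetrievalAugmentedQAPipeline:
    def __init__(self, llm, vector_db):
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str):
        # Retrieve the two most similar chunks and join them into one context string
        context_docs = self.vector_db.similarity_search(user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)

        # max_context_length is assumed to be defined earlier in app.py (not shown in this diff)
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]

        # One formatting call yields the system and human messages for this query
        messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()

        response = await self.llm.agenerate([messages])
        return {"response": response.generations[0][0].text}

agenerate takes a list of message lists and returns an LLMResult, so response.generations[0][0].text is the first generation for the single prompt that was passed in.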
@@ -88,36 +86,13 @@ async def main():
     # Streamlit UI
     st.title("Ask About AI!")
 
-    # Initialize session state for chat history
-    if "chat_history" not in st.session_state:
-        st.session_state.chat_history = []
-
     pipeline = initialize_pipeline()
 
-    # Display chat history
-    for message in st.session_state.chat_history:
-        if isinstance(message, HumanMessage):
-            st.write("You:", message.content)
-        elif isinstance(message, AIMessage):
-            st.write("AI:", message.content)
-
     user_query = st.text_input("Enter your question about AI:")
 
     if user_query:
-        # Add user message to chat history
-        st.session_state.chat_history.append(HumanMessage(content=user_query))
-
         with st.spinner("Generating response..."):
-            result = asyncio.run(pipeline.arun_pipeline(user_query, st.session_state.chat_history))
+            result = asyncio.run(pipeline.arun_pipeline(user_query))
 
-            # Add AI response to chat history
-            ai_message = AIMessage(content=result["response"])
-            st.session_state.chat_history.append(ai_message)
-
-            # Display the latest response
-            st.write("AI:", result["response"])
-
-        # Add a button to clear chat history
-        if st.button("Clear Chat History"):
-            st.session_state.chat_history = []
-            st.experimental_rerun()
+            st.write("Response:")
+            st.write(result["response"])
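With the session-state bookkeeping, history rendering, and clear-history button removed, the UI portion of main() reduces to a single question/answer round trip. A sketch of how it reads after this commit, assuming initialize_pipeline() and the imports come from the unchanged parts of app.py:

# Inside main(), after this commit
st.title("Ask About AI!")

pipeline = initialize_pipeline()

user_query = st.text_input("Enter your question about AI:")

if user_query:
    with st.spinner("Generating response..."):
        result = asyncio.run(pipeline.arun_pipeline(user_query))

        st.write("Response:")
        st.write(result["response"])

Because the chat_history argument and st.session_state.chat_history are gone, each question is now answered independently and the model no longer sees earlier turns.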