Spaces:
Sleeping
Sleeping
Technocoloredgeek
commited on
Commit
•
86c82f3
1
Parent(s):
a4b393c
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTem
|
|
6 |
from langchain_community.vectorstores import Chroma
|
7 |
from langchain_community.embeddings import OpenAIEmbeddings
|
8 |
from langchain.chat_models import ChatOpenAI
|
9 |
-
from langchain.schema import SystemMessage, HumanMessage
|
10 |
from PyPDF2 import PdfReader
|
11 |
import aiohttp
|
12 |
from io import BytesIO
|
@@ -15,7 +15,7 @@ from io import BytesIO
|
|
15 |
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
16 |
|
17 |
# Set up prompts
|
18 |
-
system_template = "
|
19 |
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
|
20 |
|
21 |
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
|
@@ -29,7 +29,7 @@ class RetrievalAugmentedQAPipeline:
|
|
29 |
self.llm = llm
|
30 |
self.vector_db = vector_db
|
31 |
|
32 |
-
async def arun_pipeline(self, user_query: str
|
33 |
context_docs = self.vector_db.similarity_search(user_query, k=2)
|
34 |
context_list = [doc.page_content for doc in context_docs]
|
35 |
context_prompt = "\n".join(context_list)
|
@@ -38,9 +38,7 @@ class RetrievalAugmentedQAPipeline:
|
|
38 |
if len(context_prompt) > max_context_length:
|
39 |
context_prompt = context_prompt[:max_context_length]
|
40 |
|
41 |
-
messages =
|
42 |
-
messages.extend(chat_history)
|
43 |
-
messages.append(HumanMessage(content=human_template.format(context=context_prompt, question=user_query)))
|
44 |
|
45 |
response = await self.llm.agenerate([messages])
|
46 |
return {"response": response.generations[0][0].text}
|
@@ -88,36 +86,13 @@ async def main():
|
|
88 |
# Streamlit UI
|
89 |
st.title("Ask About AI!")
|
90 |
|
91 |
-
# Initialize session state for chat history
|
92 |
-
if "chat_history" not in st.session_state:
|
93 |
-
st.session_state.chat_history = []
|
94 |
-
|
95 |
pipeline = initialize_pipeline()
|
96 |
|
97 |
-
# Display chat history
|
98 |
-
for message in st.session_state.chat_history:
|
99 |
-
if isinstance(message, HumanMessage):
|
100 |
-
st.write("You:", message.content)
|
101 |
-
elif isinstance(message, AIMessage):
|
102 |
-
st.write("AI:", message.content)
|
103 |
-
|
104 |
user_query = st.text_input("Enter your question about AI:")
|
105 |
|
106 |
if user_query:
|
107 |
-
# Add user message to chat history
|
108 |
-
st.session_state.chat_history.append(HumanMessage(content=user_query))
|
109 |
-
|
110 |
with st.spinner("Generating response..."):
|
111 |
-
result = asyncio.run(pipeline.arun_pipeline(user_query
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
st.session_state.chat_history.append(ai_message)
|
116 |
-
|
117 |
-
# Display the latest response
|
118 |
-
st.write("AI:", result["response"])
|
119 |
-
|
120 |
-
# Add a button to clear chat history
|
121 |
-
if st.button("Clear Chat History"):
|
122 |
-
st.session_state.chat_history = []
|
123 |
-
st.experimental_rerun()
|
|
|
6 |
from langchain_community.vectorstores import Chroma
|
7 |
from langchain_community.embeddings import OpenAIEmbeddings
|
8 |
from langchain.chat_models import ChatOpenAI
|
9 |
+
from langchain.schema import SystemMessage, HumanMessage
|
10 |
from PyPDF2 import PdfReader
|
11 |
import aiohttp
|
12 |
from io import BytesIO
|
|
|
15 |
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
16 |
|
17 |
# Set up prompts
|
18 |
+
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
|
19 |
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
|
20 |
|
21 |
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
|
|
|
29 |
self.llm = llm
|
30 |
self.vector_db = vector_db
|
31 |
|
32 |
+
async def arun_pipeline(self, user_query: str):
|
33 |
context_docs = self.vector_db.similarity_search(user_query, k=2)
|
34 |
context_list = [doc.page_content for doc in context_docs]
|
35 |
context_prompt = "\n".join(context_list)
|
|
|
38 |
if len(context_prompt) > max_context_length:
|
39 |
context_prompt = context_prompt[:max_context_length]
|
40 |
|
41 |
+
messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()
|
|
|
|
|
42 |
|
43 |
response = await self.llm.agenerate([messages])
|
44 |
return {"response": response.generations[0][0].text}
|
|
|
86 |
# Streamlit UI
|
87 |
st.title("Ask About AI!")
|
88 |
|
|
|
|
|
|
|
|
|
89 |
pipeline = initialize_pipeline()
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
user_query = st.text_input("Enter your question about AI:")
|
92 |
|
93 |
if user_query:
|
|
|
|
|
|
|
94 |
with st.spinner("Generating response..."):
|
95 |
+
result = asyncio.run(pipeline.arun_pipeline(user_query))
|
96 |
|
97 |
+
st.write("Response:")
|
98 |
+
st.write(result["response"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|