Technocoloredgeek commited on
Commit
9cd41db
1 Parent(s): 30fc578

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -2,10 +2,11 @@ import streamlit as st
2
  import asyncio
3
  import os
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
- from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
6
  from langchain_community.vectorstores import Chroma
7
  from langchain_community.embeddings import OpenAIEmbeddings
8
  from langchain.chat_models import ChatOpenAI
 
9
  from PyPDF2 import PdfReader
10
  import aiohttp
11
  from io import BytesIO
@@ -15,10 +16,12 @@ os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
15
 
16
  # Set up prompts
17
  system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
18
- system_role_prompt = SystemMessagePromptTemplate.from_template(system_template)
19
 
20
- user_prompt_template = "Context:\n{context}\n\nQuestion:\n{question}"
21
- user_role_prompt = HumanMessagePromptTemplate.from_template(user_prompt_template)
 
 
22
 
23
  # Define RetrievalAugmentedQAPipeline class
24
  class RetrievalAugmentedQAPipeline:
@@ -35,10 +38,9 @@ class RetrievalAugmentedQAPipeline:
35
  if len(context_prompt) > max_context_length:
36
  context_prompt = context_prompt[:max_context_length]
37
 
38
- formatted_system_prompt = system_role_prompt.format()
39
- formatted_user_prompt = user_role_prompt.format(question=user_query, context=context_prompt)
40
 
41
- response = await self.llm.agenerate([formatted_system_prompt, formatted_user_prompt])
42
  return {"response": response.generations[0][0].text, "context": context_list}
43
 
44
  # PDF processing functions
 
2
  import asyncio
3
  import os
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
6
  from langchain_community.vectorstores import Chroma
7
  from langchain_community.embeddings import OpenAIEmbeddings
8
  from langchain.chat_models import ChatOpenAI
9
+ from langchain.schema import SystemMessage, HumanMessage
10
  from PyPDF2 import PdfReader
11
  import aiohttp
12
  from io import BytesIO
 
16
 
17
  # Set up prompts
18
  system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
19
+ system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
20
 
21
+ human_template = "Context:\n{context}\n\nQuestion:\n{question}"
22
+ human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
23
+
24
+ chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
25
 
26
  # Define RetrievalAugmentedQAPipeline class
27
  class RetrievalAugmentedQAPipeline:
 
38
  if len(context_prompt) > max_context_length:
39
  context_prompt = context_prompt[:max_context_length]
40
 
41
+ messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()
 
42
 
43
+ response = await self.llm.agenerate([messages])
44
  return {"response": response.generations[0][0].text, "context": context_list}
45
 
46
  # PDF processing functions