Technocoloredgeek committed • Commit 9cd41db • 1 Parent(s): 30fc578
Update app.py
app.py CHANGED

@@ -2,10 +2,11 @@ import streamlit as st
 import asyncio
 import os
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
+from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain_community.vectorstores import Chroma
 from langchain_community.embeddings import OpenAIEmbeddings
 from langchain.chat_models import ChatOpenAI
+from langchain.schema import SystemMessage, HumanMessage
 from PyPDF2 import PdfReader
 import aiohttp
 from io import BytesIO
@@ -15,10 +16,12 @@ os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
 
 # Set up prompts
 system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
-
+system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
 
-
-
+human_template = "Context:\n{context}\n\nQuestion:\n{question}"
+human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
+
+chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
 
 # Define RetrievalAugmentedQAPipeline class
 class RetrievalAugmentedQAPipeline:
@@ -35,10 +38,9 @@ class RetrievalAugmentedQAPipeline:
         if len(context_prompt) > max_context_length:
             context_prompt = context_prompt[:max_context_length]
 
-
-        formatted_user_prompt = user_role_prompt.format(question=user_query, context=context_prompt)
+        messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()
 
-        response = await self.llm.agenerate([
+        response = await self.llm.agenerate([messages])
         return {"response": response.generations[0][0].text, "context": context_list}
 
 # PDF processing functions
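
For reference, a minimal sketch (not part of this commit) of what the new prompt setup produces; the context and question strings below are made-up examples:

from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

# Same templates as in app.py above.
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
human_template = "Context:\n{context}\n\nQuestion:\n{question}"

chat_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template(human_template),
])

# format_prompt fills the {context} and {question} placeholders;
# to_messages() returns [SystemMessage(...), HumanMessage(...)].
messages = chat_prompt.format_prompt(
    context="Paris is the capital of France.",  # made-up example context
    question="What is the capital of France?",
).to_messages()
print(messages)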
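
Likewise, a hedged sketch of the new agenerate call path, assuming a valid OPENAI_API_KEY in the environment; the model name is illustrative, not taken from this commit:

import asyncio

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

async def main() -> None:
    llm = ChatOpenAI(model_name="gpt-3.5-turbo")  # illustrative model choice
    messages = [
        SystemMessage(content="Use the following context to answer a user's question."),
        HumanMessage(content="Context:\nParis is the capital of France.\n\nQuestion:\nWhat is the capital of France?"),
    ]
    # agenerate takes a batch (a list of message lists) and returns an LLMResult;
    # generations[0][0].text is the first completion for the first input, which is
    # what the updated pipeline method returns as "response".
    result = await llm.agenerate([messages])
    print(result.generations[0][0].text)

asyncio.run(main())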