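"""Streamlit RAG demo: ask questions about two AI-policy PDFs.

The app downloads the White House "Blueprint for an AI Bill of Rights"
and NIST AI 600-1, splits them into chunks, embeds the chunks into an
in-memory Chroma vector store, and answers questions with a
retrieval-augmented ChatOpenAI pipeline.

Run with: streamlit run app.py  (adjust the filename to match this file).
"""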
import streamlit as st
import asyncio
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.chat_models import ChatOpenAI
from PyPDF2 import PdfReader
import aiohttp
from io import BytesIO

# Set up API key
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
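# st.secrets reads .streamlit/secrets.toml locally, or the hosting
# platform's secrets configuration when deployed.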

# Set up prompts
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)

human_template = "Context:\n{context}\n\nQuestion:\n{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
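# The system message keeps the model grounded in the retrieved context;
# the human message stuffs the context and the question into a single turn.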

# Define RetrievalAugmentedQAPipeline class
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

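    # Retrieve the most similar chunks for the query, concatenate them
    # into one context block, and generate an answer asynchronously.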
    async def arun_pipeline(self, user_query: str):
        context_docs = self.vector_db.similarity_search(user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)
        
        # Crude guard against overflowing the model's context window:
        # truncate by characters (roughly 3-4 characters per token).
        max_context_length = 12000
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]
        
        messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()

        response = await self.llm.agenerate([messages])
        return {"response": response.generations[0][0].text}

# PDF processing functions
async def fetch_pdf(session, url):
    async with session.get(url) as response:
        if response.status == 200:
            return await response.read()
        else:
            return None

# Extract text from every page and split it into ~500-character chunks
# with 40-character overlap so retrieved passages stay prompt-sized.
async def process_pdf(pdf_content):
    pdf_reader = PdfReader(BytesIO(pdf_content))
    # extract_text() can return None for pages with no extractable text
    text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)

# Build the pipeline once per server process: Streamlit reruns the whole
# script on every interaction, and st.cache_resource avoids re-downloading
# and re-embedding the PDFs each time.
@st.cache_resource
def initialize_pipeline():
    return asyncio.run(main())

# Main execution
async def main():
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]

    all_chunks = []
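    # Download both PDFs concurrently; gather returns results in input order.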
    async with aiohttp.ClientSession() as session:
        pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])
        
    for pdf_content in pdf_contents:
        if pdf_content:
            chunks = await process_pdf(pdf_content)
            all_chunks.extend(chunks)

    # Embed every chunk and index it in an in-memory Chroma vector store.
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)
    
    chat_openai = ChatOpenAI()
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)

# Streamlit UI
st.title("Ask About AI!")

pipeline = initialize_pipeline()

user_query = st.text_input("Enter your question about AI:")

if user_query:
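    # asyncio.run spins up a fresh event loop per query, which is fine
    # for a single-user demo.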
    with st.spinner("Generating response..."):
        result = asyncio.run(pipeline.arun_pipeline(user_query))
    
    st.write("Response:")
    st.write(result["response"])