Technocoloredgeek committed on
Commit
eb5a0c9
1 Parent(s): e9ac804

Update app.py

Browse files

Fixing display issue

Files changed (1) hide show
  1. app.py +103 -159
app.py CHANGED
@@ -1,159 +1,103 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 6,
6
- "metadata": {},
7
- "outputs": [
8
- {
9
- "name": "stdout",
10
- "output_type": "stream",
11
- "text": [
12
- "Note: you may need to restart the kernel to use updated packages.\n",
13
- "Created 915 chunks from 2 PDF files\n",
14
- "Query: What are the key principles of the AI Bill of Rights?\n",
15
- "\n",
16
- "Response:\n",
17
- "The key principles of the AI Bill of Rights are civil rights, civil liberties, and privacy.\n",
18
- "\n",
19
- "Context used:\n",
20
- "1. use, and deployment of automated systems to protect the rights of the American public in the age of ...\n",
21
- "2. civil rights, civil liberties, and privacy. The Blueprint for an AI Bill of Rights includes this For...\n"
22
- ]
23
- }
24
- ],
25
- "source": [
26
- "# Cell 1: Install required packages\n",
27
- "%pip install langchain openai chromadb PyPDF2 tiktoken -qU\n",
28
- "\n",
29
- "# Cell 2: Import necessary modules\n",
30
- "import os\n",
31
- "import tempfile\n",
32
- "import aiohttp\n",
33
- "import asyncio\n",
34
- "import getpass\n",
35
- "from io import BytesIO\n",
36
- "from typing import List\n",
37
- "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
38
- "from langchain.document_loaders import PyPDFLoader\n",
39
- "from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate\n",
40
- "from langchain.vectorstores import Chroma\n",
41
- "from langchain.embeddings import OpenAIEmbeddings\n",
42
- "from langchain.chat_models import ChatOpenAI\n",
43
- "from PyPDF2 import PdfReader\n",
44
- "\n",
45
- "\n",
46
- "# Cell 4: Set up prompts\n",
47
- "system_template = \"Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer.\"\n",
48
- "system_role_prompt = SystemMessagePromptTemplate.from_template(system_template)\n",
49
- "\n",
50
- "user_prompt_template = \"Context:\\n{context}\\n\\nQuestion:\\n{question}\"\n",
51
- "user_role_prompt = HumanMessagePromptTemplate.from_template(user_prompt_template)\n",
52
- "\n",
53
- "# Cell 5: Define RetrievalAugmentedQAPipeline class\n",
54
- "class RetrievalAugmentedQAPipeline:\n",
55
- " def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:\n",
56
- " self.llm = llm\n",
57
- " self.vector_db = vector_db\n",
58
- "\n",
59
- " async def arun_pipeline(self, user_query: str):\n",
60
- " context_docs = self.vector_db.similarity_search(user_query, k=2) # Reduced from 4 to 2\n",
61
- " context_list = [doc.page_content for doc in context_docs]\n",
62
- " context_prompt = \"\\n\".join(context_list)\n",
63
- " \n",
64
- " # Implement a simple truncation to ensure we don't exceed token limit\n",
65
- " max_context_length = 12000 # Adjust this value as needed\n",
66
- " if len(context_prompt) > max_context_length:\n",
67
- " context_prompt = context_prompt[:max_context_length]\n",
68
- " \n",
69
- " formatted_system_prompt = system_role_prompt.format()\n",
70
- " formatted_user_prompt = user_role_prompt.format(question=user_query, context=context_prompt)\n",
71
- "\n",
72
- " async def generate_response():\n",
73
- " async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):\n",
74
- " yield chunk.content\n",
75
- "\n",
76
- " return {\"response\": generate_response(), \"context\": context_list}\n",
77
- "\n",
78
- "# Cell 6: PDF processing functions\n",
79
- "async def fetch_pdf(session, url):\n",
80
- " async with session.get(url) as response:\n",
81
- " if response.status == 200:\n",
82
- " return await response.read()\n",
83
- " else:\n",
84
- " print(f\"Failed to fetch PDF from {url}\")\n",
85
- " return None\n",
86
- "\n",
87
- "async def process_pdf(pdf_content):\n",
88
- " pdf_reader = PdfReader(BytesIO(pdf_content))\n",
89
- " text = \"\\n\".join([page.extract_text() for page in pdf_reader.pages])\n",
90
- " text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)\n",
91
- " return text_splitter.split_text(text)\n",
92
- "\n",
93
- "# Cell 7: Main execution\n",
94
- "async def main():\n",
95
- " # Ensure API key is set\n",
96
- " api_key = get_openai_api_key()\n",
97
- "\n",
98
- " # List of PDF URLs\n",
99
- " pdf_urls = [\n",
100
- " \"https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf\",\n",
101
- " \"https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf\",\n",
102
- " ]\n",
103
- "\n",
104
- " all_chunks = []\n",
105
- " async with aiohttp.ClientSession() as session:\n",
106
- " pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])\n",
107
- " \n",
108
- " for pdf_content in pdf_contents:\n",
109
- " if pdf_content:\n",
110
- " chunks = await process_pdf(pdf_content)\n",
111
- " all_chunks.extend(chunks)\n",
112
- "\n",
113
- " print(f\"Created {len(all_chunks)} chunks from {len(pdf_urls)} PDF files\")\n",
114
- "\n",
115
- " embeddings = OpenAIEmbeddings(openai_api_key=api_key)\n",
116
- " vector_db = Chroma.from_texts(all_chunks, embeddings)\n",
117
- " \n",
118
- " chat_openai = ChatOpenAI(openai_api_key=api_key)\n",
119
- " retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)\n",
120
- " \n",
121
- " # Example query\n",
122
- " query = \"What are the key principles of the AI Bill of Rights?\"\n",
123
- " result = await retrieval_augmented_qa_pipeline.arun_pipeline(query)\n",
124
- " \n",
125
- " print(\"Query:\", query)\n",
126
- " print(\"\\nResponse:\")\n",
127
- " async for chunk in result[\"response\"]:\n",
128
- " print(chunk, end=\"\")\n",
129
- " print(\"\\n\\nContext used:\")\n",
130
- " for i, context in enumerate(result[\"context\"], 1):\n",
131
- " print(f\"{i}. {context[:100]}...\")\n",
132
- "\n",
133
- "# Cell 8: Run the main function\n",
134
- "await main()"
135
- ]
136
- }
137
- ],
138
- "metadata": {
139
- "kernelspec": {
140
- "display_name": "base",
141
- "language": "python",
142
- "name": "python3"
143
- },
144
- "language_info": {
145
- "codemirror_mode": {
146
- "name": "ipython",
147
- "version": 3
148
- },
149
- "file_extension": ".py",
150
- "mimetype": "text/x-python",
151
- "name": "python",
152
- "nbconvert_exporter": "python",
153
- "pygments_lexer": "ipython3",
154
- "version": "3.10.14"
155
- }
156
- },
157
- "nbformat": 4,
158
- "nbformat_minor": 2
159
- }
 
1
+ import streamlit as st
2
+ import asyncio
3
+ import os
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.document_loaders import PyPDFLoader
6
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
7
+ from langchain.vectorstores import Chroma
8
+ from langchain.embeddings import OpenAIEmbeddings
9
+ from langchain.chat_models import ChatOpenAI
10
+ from PyPDF2 import PdfReader
11
+ import aiohttp
12
+ from io import BytesIO
13
+
14
# Set up API key
# Copy the key from Streamlit's secrets store into the process environment.
# Presumably this is what the argument-less OpenAIEmbeddings() / ChatOpenAI()
# constructors further down rely on -- confirm against the LangChain docs.
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]

# Set up prompts
# System message: fixed instruction telling the model to answer only from context.
system_template = (
    "Use the following context to answer a user's question. "
    "If you cannot find the answer in the context, say you don't know the answer."
)
system_role_prompt = SystemMessagePromptTemplate.from_template(system_template)

# Human message: filled per query with the retrieved context and the question.
user_prompt_template = (
    "Context:\n{context}\n\n"
    "Question:\n{question}"
)
user_role_prompt = HumanMessagePromptTemplate.from_template(user_prompt_template)
23
+
24
+ # Define RetrievalAugmentedQAPipeline class
25
class RetrievalAugmentedQAPipeline:
    """Answer questions by retrieving context from a vector store and querying a chat model.

    Attributes are set directly from the constructor arguments:
    ``llm`` is the chat model used to generate the answer and ``vector_db``
    is the vector store searched for supporting passages.
    """

    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str):
        """Retrieve context for ``user_query`` and generate an answer.

        Args:
            user_query: The question typed by the user.

        Returns:
            A dict with two keys: ``"response"`` (the generated answer text)
            and ``"context"`` (the list of retrieved passage strings, in the
            order returned by the similarity search).
        """
        # Fetch the two most similar passages from the vector store.
        context_docs = self.vector_db.similarity_search(user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]

        # Character-based cap on the prompt context. Slicing is a no-op when
        # the text is already short enough, so no length check is needed.
        max_context_length = 12000
        context_prompt = "\n".join(context_list)[:max_context_length]

        formatted_system_prompt = system_role_prompt.format()
        formatted_user_prompt = user_role_prompt.format(
            question=user_query, context=context_prompt
        )

        # agenerate() takes a *batch* of conversations (a list of message
        # lists). The system and user messages form ONE conversation, so they
        # must be wrapped in an outer list; passing the flat list would treat
        # each message as its own conversation.
        response = await self.llm.agenerate(
            [[formatted_system_prompt, formatted_user_prompt]]
        )
        # First (only) conversation, first candidate generation.
        return {"response": response.generations[0][0].text, "context": context_list}
44
+
45
+ # PDF processing functions
46
async def fetch_pdf(session, url):
    """Download one PDF and return its raw bytes, or ``None`` on failure.

    Args:
        session: An open ``aiohttp.ClientSession`` used to issue the GET.
        url: Address of the PDF to download.

    Returns:
        The response body as ``bytes`` when the server answers with HTTP 200;
        otherwise ``None`` after reporting the failure with ``st.error``.
    """
    try:
        async with session.get(url) as response:
            if response.status == 200:
                return await response.read()
    except (aiohttp.ClientError, asyncio.TimeoutError):
        # Connection-level failures are reported below exactly like a bad
        # status code, so one unreachable host cannot abort the whole
        # asyncio.gather() batch this coroutine is run in.
        pass

    st.error(f"Failed to fetch PDF from {url}")
    return None
53
+
54
async def process_pdf(pdf_content, chunk_size=1000, chunk_overlap=200):
    """Extract the text of a PDF and split it into overlapping chunks.

    Args:
        pdf_content: Raw bytes of a PDF document.
        chunk_size: Maximum chunk size handed to the text splitter
            (defaults to the previously hard-coded 1000).
        chunk_overlap: Overlap between consecutive chunks handed to the text
            splitter (defaults to the previously hard-coded 200).

    Returns:
        The list of text chunks produced by ``RecursiveCharacterTextSplitter``.
    """
    pdf_reader = PdfReader(BytesIO(pdf_content))
    # ``or ""`` keeps the join safe should a page yield no text at all
    # (e.g. an image-only page); it is a no-op for pages that return a string.
    text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_text(text)
59
+
60
@st.cache_resource
def initialize_pipeline():
    """Build the QA pipeline by running the async ``main()`` to completion.

    Wrapped in ``st.cache_resource``; returns whatever ``main()`` returns.
    """
    built_pipeline = asyncio.run(main())
    return built_pipeline
63
+
64
+ # Main execution
65
async def main():
    """Download the source PDFs, index their text, and build the QA pipeline.

    Returns:
        A ``RetrievalAugmentedQAPipeline`` backed by a Chroma index of the
        chunks extracted from every PDF that downloaded successfully.

    Raises:
        RuntimeError: If no chunks could be produced from any of the PDFs,
            so there is nothing to index.
    """
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]

    # Download all PDFs concurrently; failed downloads come back as None.
    async with aiohttp.ClientSession() as session:
        pdf_contents = await asyncio.gather(
            *[fetch_pdf(session, url) for url in pdf_urls]
        )

    all_chunks = []
    loaded_pdfs = 0
    for pdf_content in pdf_contents:
        if not pdf_content:
            continue
        loaded_pdfs += 1
        all_chunks.extend(await process_pdf(pdf_content))

    # Fail with a clear message rather than handing an empty list to the
    # vector store, which could only produce a useless (or failing) index.
    if not all_chunks:
        raise RuntimeError("No text could be extracted from any of the source PDFs")

    # Report the number of PDFs actually loaded, not the number of URLs tried.
    st.write(f"Created {len(all_chunks)} chunks from {loaded_pdfs} PDF files")

    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)

    chat_openai = ChatOpenAI()
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)
87
+
88
# Streamlit UI
st.title("AI Bill of Rights Q&A")

# Obtain the pipeline through the st.cache_resource-wrapped initializer.
pipeline = initialize_pipeline()

user_query = st.text_input("Enter your question about the AI Bill of Rights:")

if user_query:
    # Run the async pipeline to completion for the submitted question.
    result = asyncio.run(pipeline.arun_pipeline(user_query))

    st.write("Response:")
    answer = result["response"]
    st.write(answer)

    st.write("Context used:")
    # Show a numbered preview (first 100 characters) of each retrieved passage.
    for position, snippet in enumerate(result["context"], start=1):
        preview = snippet[:100]
        st.write(f"{position}. {preview}...")