import io
import os
from typing import List

import chainlit as cl
from chainlit.types import AskFileResponse
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from PyPDF2 import PdfReader

# Fail fast if the API key is not configured.
if not os.getenv("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY environment variable is not set")

# Set up prompts.
system_template = (
    "Use the following context to answer a user's question. "
    "If you cannot find the answer in the context, say you don't know the answer."
)
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)

human_template = "Context:\n{context}\n\nQuestion:\n{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])


class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str):
        # Retrieve the two most similar chunks; the async variant avoids
        # blocking the event loop during the similarity search.
        context_docs = await self.vector_db.asimilarity_search(user_query, k=2)
        context_prompt = "\n".join(doc.page_content for doc in context_docs)

        # Crude guard against overlong prompts: truncate by character count.
        max_context_length = 12000
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]

        messages = chat_prompt.format_prompt(
            context=context_prompt, question=user_query
        ).to_messages()

        # Stream the model's answer token by token.
        async for chunk in self.llm.astream(messages):
            yield chunk.content


def process_pdf(file: AskFileResponse) -> List[str]:
    # PdfReader expects a path or a file-like object, so wrap the raw bytes.
    # (Newer Chainlit versions expose `file.path` instead of `file.content`.)
    pdf_reader = PdfReader(io.BytesIO(file.content))
    # extract_text() can return None for image-only pages; fall back to "".
    text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)


@cl.on_chat_start
async def on_chat_start():
    files = await cl.AskFileMessage(
        content="Please upload a PDF file to begin!",
        accept=["application/pdf"],
        max_size_mb=20,
    ).send()

    if not files:
        await cl.Message(content="No file was uploaded. Please try again.").send()
        return

    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    texts = process_pdf(file)

    # Embed the chunks and build an in-memory Chroma index for this session.
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(texts, embeddings)

    # streaming=True ensures tokens are emitted incrementally via astream().
    chat_openai = ChatOpenAI(streaming=True)
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)
    cl.user_session.set("pipeline", retrieval_augmented_qa_pipeline)

    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()


@cl.on_message
async def main(message: cl.Message):
    pipeline = cl.user_session.get("pipeline")
    if not pipeline:
        await cl.Message(content="Please upload a PDF file first.").send()
        return

    msg = cl.Message(content="")
    try:
        async for chunk in pipeline.arun_pipeline(message.content):
            await msg.stream_token(chunk)
    except Exception as e:
        await cl.Message(content=f"An error occurred: {e}").send()
        return

    await msg.send()
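

# To launch the app locally (assuming this file is saved as app.py):
#   chainlit run app.py -w
# The -w flag enables auto-reload while you edit the file.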