ChatLGO / app.py
CSAle's picture
Releasing ChatLGONoData
f03c543
raw
history blame
No virus
2.97 kB
import os
from typing import List
import chainlit as cl
from llama_index.callbacks.base import CallbackManager
from llama_index import (
ServiceContext,
StorageContext,
load_index_from_storage,
)
from llama_index.llms import OpenAI
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.query_engine import SubQuestionQueryEngine
from llama_index.embeddings import HuggingFaceEmbedding
from chainlit.types import AskFileResponse
from llama_index import download_loader
from llama_index import VectorStoreIndex
def process_file(file: AskFileResponse):
    """Parse an uploaded PDF into llama_index documents.

    The upload's raw bytes (``file.content``) are written to a temporary
    ``.pdf`` file because the ``PDFReader`` loader reads from a filesystem
    path. The temporary file is deleted once loading has finished (or failed).

    Args:
        file: The Chainlit upload returned by ``cl.AskFileMessage``.

    Returns:
        The documents produced by ``PDFReader().load_data``.
    """
    import tempfile
    from pathlib import Path

    # Write through the handle we already hold and close it before the loader
    # re-opens the path: re-opening a still-open NamedTemporaryFile by name is
    # not portable (it fails on Windows). The handle is deliberately not named
    # ``tempfile`` so the module is not shadowed.
    with tempfile.NamedTemporaryFile(mode="wb", suffix=".pdf", delete=False) as tmp:
        tmp.write(file.content)
        tmp_path = tmp.name

    try:
        PDFReader = download_loader("PDFReader")
        loader = PDFReader()
        # NOTE(review): load_data is annotated to take a Path and some
        # PDFReader versions read ``file.name`` from it; a Path is safe either way.
        documents = loader.load_data(Path(tmp_path))
    finally:
        # delete=False means nothing else removes the file; without this each
        # upload leaks a file in the temp directory.
        os.remove(tmp_path)
    return documents
def _build_query_engine(documents):
    """Embed *documents* and wrap the resulting index in a sub-question engine.

    Args:
        documents: The documents returned by ``process_file``.

    Returns:
        A ``SubQuestionQueryEngine`` whose single tool is a reranked vector
        query engine over the documents.
    """
    # One embedding model and one ServiceContext are shared by indexing and
    # querying. The index must receive it via ``service_context=``: an
    # unrecognised keyword such as ``context=`` is swallowed by
    # from_documents' **kwargs, leaving the index on the default
    # ServiceContext so the fine-tuned embedding model is never used.
    embed_model = HuggingFaceEmbedding(model_name="ai-maker-space/chatlgo-finetuned")
    llm = OpenAI(model="gpt-4-1106-preview", temperature=0)
    service_context = ServiceContext.from_defaults(
        embed_model=embed_model,
        llm=llm,
    )

    index = VectorStoreIndex.from_documents(
        documents=documents, service_context=service_context, show_progress=True
    )

    # Over-fetch 10 candidates by similarity and let Cohere rerank keep 5.
    cohere_rerank = CohereRerank(top_n=5)
    vector_engine = index.as_query_engine(
        similarity_top_k=10,
        node_postprocessors=[cohere_rerank],
        service_context=service_context,
    )

    query_engine_tools = [
        QueryEngineTool(
            query_engine=vector_engine,
            metadata=ToolMetadata(
                name="mit_theses",
                description="A collection of MIT theses.",
            ),
        ),
    ]
    return SubQuestionQueryEngine.from_defaults(
        query_engine_tools=query_engine_tools,
        service_context=service_context,
    )


@cl.on_chat_start
async def on_chat_start():
    """Ask for a PDF, index it, and store a query engine in the user session.

    Keeps prompting until an upload arrives, parses and embeds it in a worker
    thread, then stores the engine under the ``"query_engine"`` session key,
    which ``on_message`` reads.
    """
    files = None
    # send() can return None (presumably when the 180 s timeout expires);
    # keep asking until the user actually uploads a file.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # PDF parsing and embedding are blocking calls; run them off the event
    # loop, the same way on_message runs queries, so the UI stays responsive.
    documents = await cl.make_async(process_file)(file)
    query_engine = await cl.make_async(_build_query_engine)(documents)

    # Store the engine before telling the user it is ready.
    cl.user_session.set("query_engine", query_engine)

    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()
@cl.on_message
async def main(message: cl.Message):
    """Answer an incoming chat message using this session's query engine.

    The engine stored by ``on_chat_start`` exposes a blocking ``query``
    method, so it is wrapped with ``cl.make_async`` and awaited; the
    stringified response is posted back as a new chat message.
    """
    engine = cl.user_session.get("query_engine")
    run_query = cl.make_async(engine.query)
    answer = await run_query(message.content)
    await cl.Message(content=str(answer)).send()