File size: 2,967 Bytes
5795fcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f03c543
5795fcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f03c543
 
 
 
 
 
 
5795fcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
from typing import List

import chainlit as cl

from llama_index.callbacks.base import CallbackManager
from llama_index import (
    ServiceContext,
    StorageContext,
    load_index_from_storage,
)
from llama_index.llms import OpenAI
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.query_engine import SubQuestionQueryEngine
from llama_index.embeddings import HuggingFaceEmbedding
from chainlit.types import AskFileResponse
from llama_index import download_loader
from llama_index import VectorStoreIndex


def process_file(file: AskFileResponse):
    """Parse an uploaded PDF into llama_index documents.

    The upload's raw bytes are written to a temporary ``.pdf`` file so the
    ``PDFReader`` loader (obtained via ``download_loader``) can read it from
    disk. The temporary file is removed once parsing has finished, whether or
    not parsing succeeded.

    Args:
        file: The Chainlit upload; its ``content`` attribute is written to disk
            as bytes.

    Returns:
        Whatever ``PDFReader.load_data`` returns for the file (the documents
        that are later indexed).
    """
    import os
    import tempfile
    from pathlib import Path

    # Write through the handle we already hold, in binary mode. Reopening the
    # same name while this handle is still open is not portable (it fails on
    # Windows), and naming the handle ``tempfile`` would shadow the module.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
        tmp.write(file.content)
        tmp_path = Path(tmp.name)

    try:
        PDFReader = download_loader("PDFReader")
        loader = PDFReader()
        # Pass a Path rather than a str: the loader's ``file`` parameter is
        # annotated as ``Path``. NOTE(review): some loader versions read
        # ``file.name`` from it, which a plain str would not have — confirm.
        return loader.load_data(tmp_path)
    finally:
        # delete=False keeps the file alive after the ``with`` so the loader
        # can reopen it; remove it ourselves so uploads don't pile up.
        os.remove(tmp_path)


def _build_query_engine(documents):
    """Index *documents* and wrap the index in a sub-question query engine.

    This is synchronous and potentially slow: it loads the embedding model and
    embeds every document. Async handlers should call it through
    ``cl.make_async`` so the event loop is not blocked.

    Args:
        documents: The documents returned by ``process_file``.

    Returns:
        A ``SubQuestionQueryEngine`` whose single tool queries a reranked
        vector index over *documents*.
    """
    llm = OpenAI(model="gpt-4-1106-preview", temperature=0)
    embed_model = HuggingFaceEmbedding(model_name="ai-maker-space/chatlgo-finetuned")

    # One service context shared by indexing and querying, so the fine-tuned
    # embedding model is loaded once and used for both.
    service_context = ServiceContext.from_defaults(
        embed_model=embed_model,
        llm=llm,
    )

    # ``service_context`` is the keyword from_documents accepts; passing the
    # context under a different name (e.g. ``context=``) does not configure
    # the embedding model used to build the index.
    index = VectorStoreIndex.from_documents(
        documents=documents, service_context=service_context, show_progress=True
    )

    # Retrieve 10 candidates, then keep the best 5 after Cohere reranking.
    cohere_rerank = CohereRerank(top_n=5)

    vector_query_engine = index.as_query_engine(
        similarity_top_k=10,
        node_postprocessors=[cohere_rerank],
        service_context=service_context,
    )

    # NOTE(review): the tool is described as MIT theses regardless of which
    # PDF was uploaded; presumably the sub-question generator reads this
    # description, so consider making it match the actual upload.
    query_engine_tools = [
        QueryEngineTool(
            query_engine=vector_query_engine,
            metadata=ToolMetadata(
                name="mit_theses",
                description="A collection of MIT theses.",
            ),
        ),
    ]

    return SubQuestionQueryEngine.from_defaults(
        query_engine_tools=query_engine_tools,
        service_context=service_context,
    )


@cl.on_chat_start
async def on_chat_start():
    """Ask the user for a PDF, index it, and store a query engine in the session.

    The resulting engine is saved under the ``"query_engine"`` key of
    ``cl.user_session``, where the message handler picks it up.
    """
    files = None

    # send() can return None (e.g. when the 180 s timeout elapses without an
    # upload), so keep asking until a file arrives.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()

    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # PDF parsing and index construction are blocking; run them in worker
    # threads via make_async, as the message handler does for queries.
    documents = await cl.make_async(process_file)(file)
    query_engine = await cl.make_async(_build_query_engine)(documents)

    cl.user_session.set("query_engine", query_engine)

    # Tell the user the upload is ready instead of leaving "Processing..." up.
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()


@cl.on_message
async def main(message: cl.Message):
    """Answer *message* using the query engine stored in the user session.

    If no query engine has been stored yet (no PDF has finished processing in
    this session), ask the user to upload one instead of raising.

    Args:
        message: The incoming chat message; its ``content`` is the question.
    """
    query_engine = cl.user_session.get("query_engine")
    if query_engine is None:
        await cl.Message(content="Please upload a PDF file first!").send()
        return

    # query() is synchronous; make_async runs it in a worker thread so the
    # event loop stays responsive while the LLM calls are in flight.
    response = await cl.make_async(query_engine.query)(message.content)

    await cl.Message(content=str(response)).send()