yilin1344 committed
Commit 7e703d7
1 Parent(s): 402e31a
Files changed (1)
  1. app.py +46 -89
app.py CHANGED
@@ -1,110 +1,67 @@
-import chainlit as cl
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.document_loaders.csv_loader import CSVLoader
-from langchain.embeddings import CacheBackedEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import FAISS
-from langchain.chains import RetrievalQA
-from langchain.chat_models import ChatOpenAI
-from langchain.storage import LocalFileStore
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    SystemMessagePromptTemplate,
-    HumanMessagePromptTemplate,
+import os
+import openai
+
+from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
+from llama_index.callbacks.base import CallbackManager
+from llama_index import (
+    LLMPredictor,
+    ServiceContext,
+    StorageContext,
+    load_index_from_storage,
 )
+from langchain.chat_models import ChatOpenAI
 import chainlit as cl
 
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-
-system_template = """
-Use the following pieces of context to answer the user's question.
-Please respond as if you were Ken from the movie Barbie. Ken is a well-meaning but naive character who loves to Beach. He talks like a typical Californian Beach Bro, but he doesn't use the word "Dude" so much.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-You can make inferences based on the context as long as they still faithfully represent the feedback.
-
-An example of your response should be:
-
-```
-The answer is foo
-```
-
-Begin!
-----------------
-{context}"""
-
-messages = [
-    SystemMessagePromptTemplate.from_template(system_template),
-    HumanMessagePromptTemplate.from_template("{question}"),
-]
-prompt = ChatPromptTemplate(messages=messages)
-chain_type_kwargs = {"prompt": prompt}
-
-@cl.author_rename
-def rename(orig_author: str):
-    rename_dict = {"RetrievalQA": "Consulting The Kens"}
-    return rename_dict.get(orig_author, orig_author)
-
+openai.api_key = os.environ.get("OPENAI_API_KEY")
+
+try:
+    # rebuild storage context
+    storage_context = StorageContext.from_defaults(persist_dir="./storage")
+    # load index
+    index = load_index_from_storage(storage_context)
+except Exception:
+    from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader
+
+    documents = SimpleDirectoryReader("./data").load_data()
+    index = GPTVectorStoreIndex.from_documents(documents)
+    index.storage_context.persist()
+
+
 @cl.on_chat_start
-async def init():
-    msg = cl.Message(content="Building Index...")
-    await msg.send()
-
-    # build FAISS index from csv
-    loader = CSVLoader(file_path="./data/barbie.csv", source_column="Review_Url")
-    data = loader.load()
-    documents = text_splitter.transform_documents(data)
-    store = LocalFileStore("./cache/")
-    core_embeddings_model = OpenAIEmbeddings()
-    embedder = CacheBackedEmbeddings.from_bytes_store(
-        core_embeddings_model, store, namespace=core_embeddings_model.model
+async def factory():
+    llm_predictor = LLMPredictor(
+        llm=ChatOpenAI(
+            temperature=0,
+            model_name="gpt-3.5-turbo",
+            streaming=True,
+        ),
     )
-    # make async docsearch
-    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)
-
-    chain = RetrievalQA.from_chain_type(
-        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
-        chain_type="stuff",
-        return_source_documents=True,
-        retriever=docsearch.as_retriever(),
-        chain_type_kwargs={"prompt": prompt},
+    service_context = ServiceContext.from_defaults(
+        llm_predictor=llm_predictor,
+        chunk_size=512,
+        callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
     )
 
-    msg.content = "Index built!"
-    await msg.send()
-
-    cl.user_session.set("chain", chain)
+    query_engine = index.as_query_engine(
+        service_context=service_context,
+        streaming=True,
+    )
+
+    cl.user_session.set("query_engine", query_engine)
 
 
 @cl.on_message
 async def main(message):
-    chain = cl.user_session.get("chain")
-    cb = cl.AsyncLangchainCallbackHandler(
-        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
-    )
-    cb.answer_reached = True
-    res = await chain.acall(message, callbacks=[cb])
-
-    answer = res["result"]
-    source_elements = []
-    visited_sources = set()
-
-    # Get the source documents returned with the answer
-    docs = res["source_documents"]
-    metadatas = [doc.metadata for doc in docs]
-    all_sources = [m["source"] for m in metadatas]
-
-    for source in all_sources:
-        if source in visited_sources:
-            continue
-        visited_sources.add(source)
-        # Create the text element referenced in the message
-        source_elements.append(
-            cl.Text(content="https://www.imdb.com" + source, name="Review URL")
-        )
-
-    if source_elements:
-        answer += f"\nSources: {', '.join([e.content for e in source_elements])}"
-    else:
-        answer += "\nNo sources found"
-
-    await cl.Message(content=answer, elements=source_elements).send()
+    query_engine = cl.user_session.get("query_engine")  # type: RetrieverQueryEngine
+    response = await cl.make_async(query_engine.query)(message)
+
+    response_message = cl.Message(content="")
+
+    for token in response.response_gen:
+        await response_message.stream_token(token=token)
+
+    if response.response_txt:
+        response_message.content = response.response_txt
+
+    await response_message.send()
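
For reference, here is a minimal sketch of the persist-and-reload round trip that the new app.py's try/except block relies on, using the same llama_index calls that appear in this commit. It assumes the llama_index version this commit targets, source files under ./data, and OPENAI_API_KEY set in the environment; build_index.py is a hypothetical helper name, not a file in this commit.

```python
# build_index.py -- hypothetical offline helper, not part of this commit.
# Builds the vector index once and persists it to ./storage, so the app's
# try branch can reload it instead of re-embedding the corpus on every start.
from llama_index import (
    GPTVectorStoreIndex,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)

# First run: read raw files from ./data, embed them, and persist to disk.
documents = SimpleDirectoryReader("./data").load_data()
index = GPTVectorStoreIndex.from_documents(documents)
index.storage_context.persist(persist_dir="./storage")

# Later runs: rebuild the storage context and reload the same index.
storage_context = StorageContext.from_defaults(persist_dir="./storage")
index = load_index_from_storage(storage_context)

# Smoke test (makes an OpenAI call, so it needs OPENAI_API_KEY).
print(index.as_query_engine().query("What documents are in the corpus?"))
```

Persisting the index this way means a restarted Space only re-embeds the corpus when ./storage is missing, which is exactly the case the try/except in app.py guards against.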