IlyaGusev commited on
Commit
8e5a242
1 Parent(s): 284b100

Initial commit

Browse files
Files changed (2) hide show
  1. app.py +252 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from huggingface_hub import snapshot_download
4
+ from langchain.document_loaders import (
5
+ CSVLoader,
6
+ EverNoteLoader,
7
+ PDFMinerLoader,
8
+ TextLoader,
9
+ UnstructuredEmailLoader,
10
+ UnstructuredEPubLoader,
11
+ UnstructuredHTMLLoader,
12
+ UnstructuredMarkdownLoader,
13
+ UnstructuredODTLoader,
14
+ UnstructuredPowerPointLoader,
15
+ UnstructuredWordDocumentLoader,
16
+ )
17
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
18
+ from langchain.vectorstores import Chroma
19
+ from langchain.embeddings import HuggingFaceEmbeddings
20
+ from langchain.docstore.document import Document
21
+ from chromadb.config import Settings
22
+ from llama_cpp import Llama
23
+
24
+
25
+ SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
26
+ SYSTEM_TOKEN = 1788
27
+ USER_TOKEN = 1404
28
+ BOT_TOKEN = 9225
29
+ LINEBREAK_TOKEN = 13
30
+
31
+ ROLE_TOKENS = {
32
+ "user": USER_TOKEN,
33
+ "bot": BOT_TOKEN,
34
+ "system": SYSTEM_TOKEN
35
+ }
36
+
37
+ LOADER_MAPPING = {
38
+ ".csv": (CSVLoader, {}),
39
+ ".doc": (UnstructuredWordDocumentLoader, {}),
40
+ ".docx": (UnstructuredWordDocumentLoader, {}),
41
+ ".enex": (EverNoteLoader, {}),
42
+ ".epub": (UnstructuredEPubLoader, {}),
43
+ ".html": (UnstructuredHTMLLoader, {}),
44
+ ".md": (UnstructuredMarkdownLoader, {}),
45
+ ".odt": (UnstructuredODTLoader, {}),
46
+ ".pdf": (PDFMinerLoader, {}),
47
+ ".ppt": (UnstructuredPowerPointLoader, {}),
48
+ ".pptx": (UnstructuredPowerPointLoader, {}),
49
+ ".txt": (TextLoader, {"encoding": "utf8"}),
50
+ }
51
+
52
+
53
+ MODEL_NAME = "ggml-model-q4_1.bin"
54
+ snapshot_download(
55
+ repo_id="IlyaGusev/saiga_7b_lora_llamacpp",
56
+ local_dir=".",
57
+ allow_patterns=MODEL_NAME
58
+ )
59
+
60
+
61
+ model = Llama(
62
+ model_path=MODEL_NAME,
63
+ n_ctx=2000,
64
+ n_parts=1,
65
+ )
66
+
67
+ max_new_tokens = 1500
68
+ top_k = 30
69
+ top_p = 0.9
70
+ temp = 0.1
71
+ repeat_penalty = 1.15
72
+ chunk_size = 300
73
+
74
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
75
+
76
+
77
+ def load_single_document(file_path: str) -> Document:
78
+ ext = "." + file_path.rsplit(".", 1)[-1]
79
+ assert ext in LOADER_MAPPING
80
+ loader_class, loader_args = LOADER_MAPPING[ext]
81
+ loader = loader_class(file_path, **loader_args)
82
+ return loader.load()[0]
83
+
84
+
85
+ def get_message_tokens(model, role, content):
86
+ message_tokens = model.tokenize(content.encode("utf-8"))
87
+ message_tokens.insert(1, ROLE_TOKENS[role])
88
+ message_tokens.insert(2, LINEBREAK_TOKEN)
89
+ message_tokens.append(model.token_eos())
90
+ return message_tokens
91
+
92
+
93
+ def get_system_tokens(model):
94
+ system_message = {"role": "system", "content": SYSTEM_PROMPT}
95
+ return get_message_tokens(model, **system_message)
96
+
97
+
98
+ def upload_files(files, file_paths):
99
+ file_paths = [f.name for f in files]
100
+ return file_paths
101
+
102
+
103
+ def build_index(file_paths, db):
104
+ documents = [load_single_document(path) for path in file_paths]
105
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=20)
106
+ texts = text_splitter.split_documents(documents)
107
+ def fix_lines(text):
108
+ lines = text.split("\n")
109
+ lines = [line for line in lines if len(line.strip()) > 2]
110
+ return "\n".join(lines)
111
+ fixed_texts = []
112
+ for text in texts:
113
+ text.page_content = fix_lines(text.page_content)
114
+ if len(text.page_content) < 10:
115
+ continue
116
+ fixed_texts.append(text)
117
+
118
+ db = Chroma.from_documents(
119
+ fixed_texts,
120
+ embeddings,
121
+ client_settings=Settings(
122
+ anonymized_telemetry=False
123
+ )
124
+ )
125
+ return db
126
+
127
+
128
+ def user(message, history, system_prompt):
129
+ new_history = history + [[message, None]]
130
+ return "", new_history
131
+
132
+
133
+ def bot(history, system_prompt, conversation_id, db):
134
+ if not history:
135
+ return
136
+
137
+ tokens = get_system_tokens(model)[:]
138
+ tokens.append(LINEBREAK_TOKEN)
139
+
140
+ for user_message, bot_message in history[:-1]:
141
+ message_tokens = get_message_tokens(model=model, role="user", content=user_message)
142
+ tokens.extend(message_tokens)
143
+ if bot_message:
144
+ message_tokens = get_message_tokens(model=model, role="bot", content=bot_message)
145
+ tokens.extend(message_tokens)
146
+
147
+ last_user_message = history[-1][0]
148
+ if db:
149
+ retriever = db.as_retriever(search_kwargs={"k": 2})
150
+ docs = retriever.get_relevant_documents(last_user_message)
151
+ context = "\n\n".join([doc.page_content for doc in docs])
152
+ last_user_message = f"Контекст: {context}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
153
+ message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
154
+ tokens.extend(message_tokens)
155
+
156
+ role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
157
+ tokens.extend(role_tokens)
158
+ generator = model.generate(
159
+ tokens,
160
+ top_k=top_k,
161
+ top_p=top_p,
162
+ temp=temp,
163
+ repeat_penalty=repeat_penalty
164
+ )
165
+
166
+ completion_tokens = []
167
+ partial_text = ""
168
+ for i, token in enumerate(generator):
169
+ completion_tokens.append(token)
170
+ if token == model.token_eos() or (max_new_tokens is not None and i >= max_new_tokens):
171
+ break
172
+ partial_text = model.detokenize(completion_tokens).decode("utf-8", "ignore")
173
+ history[-1][1] = partial_text
174
+ yield history
175
+
176
+
177
+ with gr.Blocks(
178
+ theme=gr.themes.Soft()
179
+ ) as demo:
180
+ db = gr.State(None)
181
+ conversation_id = gr.State(get_uuid)
182
+ favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
183
+ gr.Markdown(
184
+ f"""<h1><center>{favicon}Saiga 7B Retrieval QA Llama.cpp</center></h1>
185
+ """
186
+ )
187
+
188
+ system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT)
189
+
190
+ file_output = gr.File(file_count="multiple")
191
+ file_paths = gr.State([])
192
+
193
+ chatbot = gr.Chatbot().style(height=400)
194
+ with gr.Row():
195
+ with gr.Column():
196
+ msg = gr.Textbox(
197
+ label="Отправить сообщение",
198
+ placeholder="Отправить сообщение",
199
+ show_label=False,
200
+ ).style(container=False)
201
+ with gr.Column():
202
+ with gr.Row():
203
+ submit = gr.Button("Отправить")
204
+ stop = gr.Button("Остановить")
205
+ clear = gr.Button("Очистить")
206
+
207
+ upload_event = file_output.change(
208
+ fn=upload_files,
209
+ inputs=[file_output, file_paths],
210
+ outputs=[file_paths],
211
+ queue=False,
212
+ ).then(
213
+ fn=build_index,
214
+ inputs=[file_paths, db],
215
+ outputs=[db],
216
+ queue=True
217
+ )
218
+
219
+ submit_event = msg.submit(
220
+ fn=user,
221
+ inputs=[msg, chatbot, system_prompt],
222
+ outputs=[msg, chatbot],
223
+ queue=False,
224
+ ).then(
225
+ fn=bot,
226
+ inputs=[chatbot, system_prompt, conversation_id, db],
227
+ outputs=chatbot,
228
+ queue=True,
229
+ )
230
+
231
+ submit_click_event = submit.click(
232
+ fn=user,
233
+ inputs=[msg, chatbot, system_prompt],
234
+ outputs=[msg, chatbot],
235
+ queue=False,
236
+ ).then(
237
+ fn=bot,
238
+ inputs=[chatbot, system_prompt, conversation_id, db],
239
+ outputs=chatbot,
240
+ queue=True,
241
+ )
242
+ stop.click(
243
+ fn=None,
244
+ inputs=None,
245
+ outputs=None,
246
+ cancels=[submit_event, submit_click_event],
247
+ queue=False,
248
+ )
249
+ clear.click(lambda: None, None, chatbot, queue=False)
250
+
251
+ demo.queue(max_size=128, concurrency_count=1)
252
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ llama-cpp-python==0.1.53
2
+ langchain==0.0.174
3
+ huggingface-hub==0.14.1
4
+ chromadb=0.3.23
5
+ pdfminer.six==20221105