IlyaGusev committed on
Commit a6787f7
1 Parent(s): 258d8d2

Many interface changes

Files changed (2)
  1. README.md +1 -3
  2. app.py +163 -53
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Saiga 7b Retrieval Qa
+title: Saiga 7b llama.cpp Retrieval QA
 emoji: 🚀
 colorFrom: pink
 colorTo: pink
@@ -8,5 +8,3 @@ sdk_version: 3.32.0
 app_file: app.py
 pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -51,28 +51,20 @@ LOADER_MAPPING = {
 }
 
 
-MODEL_NAME = "ggml-model-q4_1.bin"
-snapshot_download(
-    repo_id="IlyaGusev/saiga_7b_lora_llamacpp",
-    local_dir=".",
-    allow_patterns=MODEL_NAME
-)
+repo_name = "IlyaGusev/saiga_7b_lora_llamacpp"
+model_name = "ggml-model-q8_0.bin"
+embedder_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
 
+snapshot_download(repo_id=repo_name, local_dir=".", allow_patterns=model_name)
 
 model = Llama(
-    model_path=MODEL_NAME,
+    model_path=model_name,
     n_ctx=2000,
     n_parts=1,
 )
 
 max_new_tokens = 1500
-top_k = 30
-top_p = 0.9
-temp = 0.1
-repeat_penalty = 1.15
-chunk_size = 300
-
-embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
+embeddings = HuggingFaceEmbeddings(model_name=embedder_name)
 
 def get_uuid():
     return str(uuid4())
@@ -104,29 +96,35 @@ def upload_files(files, file_paths):
     return file_paths
 
 
-def build_index(file_paths, db):
+def process_text(text):
+    lines = text.split("\n")
+    lines = [line for line in lines if len(line.strip()) > 2]
+    text = "\n".join(lines).strip()
+    if len(text) < 10:
+        return None
+    return text
+
+
+def build_index(file_paths, db, chunk_size, chunk_overlap, file_warning):
     documents = [load_single_document(path) for path in file_paths]
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=20)
-    texts = text_splitter.split_documents(documents)
-    def fix_lines(text):
-        lines = text.split("\n")
-        lines = [line for line in lines if len(line.strip()) > 2]
-        return "\n".join(lines)
-    fixed_texts = []
-    for text in texts:
-        text.page_content = fix_lines(text.page_content)
-        if len(text.page_content) < 10:
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+    documents = text_splitter.split_documents(documents)
+    fixed_documents = []
+    for doc in documents:
+        doc.page_content = process_text(doc.page_content)
+        if not doc.page_content:
             continue
-        fixed_texts.append(text)
+        fixed_documents.append(doc)
 
     db = Chroma.from_documents(
-        fixed_texts,
+        fixed_documents,
         embeddings,
         client_settings=Settings(
            anonymized_telemetry=False
        )
    )
-    return db
+    file_warning = f"Загружен {len(fixed_documents)} фрагментов! Можно задавать вопросы."
+    return db, file_warning
 
 
 def user(message, history, system_prompt):
@@ -134,7 +132,25 @@ def user(message, history, system_prompt):
     return "", new_history
 
 
-def bot(history, system_prompt, conversation_id, db):
+def retrieve(history, db, retrieved_docs, k_documents):
+    context = ""
+    if db:
+        last_user_message = history[-1][0]
+        retriever = db.as_retriever(search_kwargs={"k": k_documents})
+        docs = retriever.get_relevant_documents(last_user_message)
+        retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
+    return retrieved_docs
+
+
+def bot(
+    history,
+    system_prompt,
+    conversation_id,
+    retrieved_docs,
+    top_p,
+    top_k,
+    temp
+):
     if not history:
         return
 
@@ -149,11 +165,8 @@ def bot(history, system_prompt, conversation_id, db)
     tokens.extend(message_tokens)
 
     last_user_message = history[-1][0]
-    if db:
-        retriever = db.as_retriever(search_kwargs={"k": 2})
-        docs = retriever.get_relevant_documents(last_user_message)
-        context = "\n\n".join([doc.page_content for doc in docs])
-        last_user_message = f"Контекст: {context}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
+    if retrieved_docs:
+        last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
     message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
     tokens.extend(message_tokens)
 
@@ -163,17 +176,14 @@ def bot(history, system_prompt, conversation_id, db)
         tokens,
         top_k=top_k,
         top_p=top_p,
-        temp=temp,
-        repeat_penalty=repeat_penalty
+        temp=temp
     )
 
-    completion_tokens = []
     partial_text = ""
     for i, token in enumerate(generator):
-        completion_tokens.append(token)
         if token == model.token_eos() or (max_new_tokens is not None and i >= max_new_tokens):
            break
-        partial_text = model.detokenize(completion_tokens).decode("utf-8", "ignore")
+        partial_text += model.detokenize([token]).decode("utf-8", "ignore")
        history[-1][1] = partial_text
        yield history
 
@@ -185,16 +195,83 @@ with gr.Blocks(
     conversation_id = gr.State(get_uuid)
     favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
     gr.Markdown(
-        f"""<h1><center>{favicon}Saiga 7B Retrieval QA Llama.cpp</center></h1>
+        f"""<h1><center>{favicon}Saiga 7B llama.cpp: retrieval QA</center></h1>
         """
     )
 
-    system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT)
-
-    file_output = gr.File(file_count="multiple")
-    file_paths = gr.State([])
-
-    chatbot = gr.Chatbot().style(height=400)
+    with gr.Row():
+        with gr.Column(scale=5):
+            file_output = gr.File(file_count="multiple", label="Загрузка файлов")
+            file_paths = gr.State([])
+            file_warning = gr.Markdown(f"Фрагменты ещё не загружены!")
+
+        with gr.Column(min_width=200, scale=3):
+            with gr.Tab(label="Параметры нарезки"):
+                chunk_size = gr.Slider(
+                    minimum=50,
+                    maximum=2000,
+                    value=250,
+                    step=50,
+                    interactive=True,
+                    label="Размер фрагментов",
+                )
+                chunk_overlap = gr.Slider(
+                    minimum=0,
+                    maximum=500,
+                    value=30,
+                    step=10,
+                    interactive=True,
+                    label="Пересечение"
+                )
+
+    with gr.Row():
+        k_documents = gr.Slider(
+            minimum=1,
+            maximum=10,
+            value=2,
+            step=1,
+            interactive=True,
+            label="Кол-во фрагментов для контекста"
+        )
+    with gr.Row():
+        retrieved_docs = gr.Textbox(
+            lines=6,
+            label="Извлеченные фрагменты",
+            placeholder="Появятся после задавания вопросов",
+            interactive=False
+        )
+    with gr.Row():
+        with gr.Column(scale=5):
+            system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False)
+            chatbot = gr.Chatbot(label="Диалог").style(height=400)
+        with gr.Column(min_width=80, scale=1):
+            with gr.Tab(label="Параметры генерации"):
+                top_p = gr.Slider(
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.9,
+                    step=0.05,
+                    interactive=True,
+                    label="Top-p",
+                )
+                top_k = gr.Slider(
+                    minimum=10,
+                    maximum=100,
+                    value=30,
+                    step=5,
+                    interactive=True,
+                    label="Top-k",
+                )
+                temp = gr.Slider(
+                    minimum=0.0,
+                    maximum=2.0,
+                    value=0.1,
+                    step=0.1,
+                    interactive=True,
+                    label="Temp"
+                )
+
     with gr.Row():
         with gr.Column():
             msg = gr.Textbox(
@@ -208,41 +285,72 @@ with gr.Blocks(
                 stop = gr.Button("Остановить")
                 clear = gr.Button("Очистить")
 
+    # Upload files
     upload_event = file_output.change(
         fn=upload_files,
         inputs=[file_output, file_paths],
         outputs=[file_paths],
-        queue=False,
-    ).then(
+        queue=True,
+    ).success(
         fn=build_index,
-        inputs=[file_paths, db],
-        outputs=[db],
+        inputs=[file_paths, db, chunk_size, chunk_overlap, file_warning],
+        outputs=[db, file_warning],
         queue=True
     )
 
+    # Pressing Enter
     submit_event = msg.submit(
         fn=user,
         inputs=[msg, chatbot, system_prompt],
         outputs=[msg, chatbot],
         queue=False,
-    ).then(
+    ).success(
+        fn=retrieve,
+        inputs=[chatbot, db, retrieved_docs, k_documents],
+        outputs=[retrieved_docs],
+        queue=True,
+    ).success(
         fn=bot,
-        inputs=[chatbot, system_prompt, conversation_id, db],
+        inputs=[
+            chatbot,
+            system_prompt,
+            conversation_id,
+            retrieved_docs,
+            top_p,
+            top_k,
+            temp
+        ],
         outputs=chatbot,
         queue=True,
     )
 
+    # Pressing the button
     submit_click_event = submit.click(
         fn=user,
         inputs=[msg, chatbot, system_prompt],
         outputs=[msg, chatbot],
         queue=False,
-    ).then(
+    ).success(
+        fn=retrieve,
+        inputs=[chatbot, db, retrieved_docs, k_documents],
+        outputs=[retrieved_docs],
+        queue=True,
+    ).success(
         fn=bot,
-        inputs=[chatbot, system_prompt, conversation_id, db],
+        inputs=[
+            chatbot,
+            system_prompt,
+            conversation_id,
+            retrieved_docs,
+            top_p,
+            top_k,
+            temp
+        ],
        outputs=chatbot,
        queue=True,
    )
+
+    # Stop generation
     stop.click(
         fn=None,
         inputs=None,
@@ -250,7 +358,9 @@ with gr.Blocks(
         cancels=[submit_event, submit_click_event],
         queue=False,
     )
+
+    # Clear history
     clear.click(lambda: None, None, chatbot, queue=False)
 
     demo.queue(max_size=128, concurrency_count=1)
-    demo.launch()
+    demo.launch(share=True)
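
Taken together, the app.py hunks replace the hard-coded retrieval settings with slider-controlled ones and split the old bot() into retrieve() plus bot(): uploaded files are chunked and indexed in Chroma, the top-k fragments are shown in a textbox and prepended to the user message as context, and the answer is streamed token by token from llama.cpp. The snippet below is a minimal standalone sketch of that flow, assuming the same libraries app.py already imports (llama-cpp-python, langchain, chromadb); the sample texts, the plain-string prompt, and the hard-coded parameter values are illustrative and are not part of this commit.

# Minimal sketch of the retrieval QA flow introduced by this commit.
# Assumes llama-cpp-python, langchain and chromadb are installed and the
# GGML model file has already been downloaded (app.py does this via snapshot_download).
from llama_cpp import Llama
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
model = Llama(model_path="ggml-model-q8_0.bin", n_ctx=2000)

# 1. Chunk the documents (app.py loads real files with load_single_document).
raw_texts = ["Первый длинный документ...", "Второй документ..."]  # illustrative stand-ins
splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=30)
docs = splitter.split_documents([Document(page_content=t) for t in raw_texts])

# 2. Build the index, as build_index() now does.
db = Chroma.from_documents(docs, embeddings)

# 3. Retrieve k fragments for the question and fold them into the prompt,
#    mirroring retrieve() and the context handling in bot().
question = "О чём эти документы?"
retriever = db.as_retriever(search_kwargs={"k": 2})
fragments = retriever.get_relevant_documents(question)
context = "\n\n".join(doc.page_content for doc in fragments)
prompt = f"Контекст: {context}\n\nИспользуя контекст, ответь на вопрос: {question}"

# 4. Stream the answer token by token; as in the new bot(), each token is
#    detokenized separately and appended to the partial text.
tokens = model.tokenize(prompt.encode("utf-8"))
partial_text = ""
for i, token in enumerate(model.generate(tokens, top_k=30, top_p=0.9, temp=0.1)):
    if token == model.token_eos() or i >= 1500:
        break
    partial_text += model.detokenize([token]).decode("utf-8", "ignore")
print(partial_text)

Note that app.py builds its prompt with get_message_tokens() and role-tagged messages; the sketch tokenizes a raw string for brevity, so its outputs will differ from the Space.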