|
import shutil |
|
import os |
|
import gradio as gr |
|
|
|
import torch |
|
from uuid import uuid4 |
|
from huggingface_hub.file_download import http_get |
|
from langchain.document_loaders import ( |
|
CSVLoader, |
|
EverNoteLoader, |
|
PDFMinerLoader, |
|
TextLoader, |
|
UnstructuredEmailLoader, |
|
UnstructuredEPubLoader, |
|
UnstructuredHTMLLoader, |
|
UnstructuredMarkdownLoader, |
|
UnstructuredODTLoader, |
|
UnstructuredPowerPointLoader, |
|
UnstructuredWordDocumentLoader, |
|
) |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.docstore.document import Document |
|
from sentence_transformers import SentenceTransformer |
|
from sentence_transformers.util import cos_sim |
|
from llama_cpp import Llama |
|
|
|
|
|
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."

# Raw token ids that bot() splices into the hand-assembled prompt. bot() uses
# these names, so they must exist at module level or the first reply fails with
# NameError.
# 13 is the "\n" byte token (<0x0A>) of the LLaMA / Llama-2 SentencePiece
# vocabulary. 9225 is the "bot" role token used by the reference Saiga
# llama.cpp demo.
# TODO: confirm both ids against the tokenizer if the model is swapped.
BOT_TOKEN = 9225
LINEBREAK_TOKEN = 13
|
|
|
# File extension (with its leading dot, spelled exactly as in the keys below)
# -> (langchain loader class, extra keyword arguments for that loader).
# Each entry gets its own kwargs dict.
LOADER_MAPPING = dict(
    [
        (".csv", (CSVLoader, {})),
        (".doc", (UnstructuredWordDocumentLoader, {})),
        (".docx", (UnstructuredWordDocumentLoader, {})),
        (".enex", (EverNoteLoader, {})),
        (".epub", (UnstructuredEPubLoader, {})),
        (".html", (UnstructuredHTMLLoader, {})),
        (".md", (UnstructuredMarkdownLoader, {})),
        (".odt", (UnstructuredODTLoader, {})),
        (".pdf", (PDFMinerLoader, {})),
        (".ppt", (UnstructuredPowerPointLoader, {})),
        (".pptx", (UnstructuredPowerPointLoader, {})),
        # Plain text needs an explicit encoding.
        (".txt", (TextLoader, {"encoding": "utf8"})),
    ]
)
|
|
|
|
|
def load_model(
    directory: str = ".",
    model_name: str = "model-q4_K.gguf",
    model_url: str = "https://huggingface.co/IlyaGusev/saiga2_13b_gguf/resolve/main/model-q4_K.gguf",
    n_ctx: int = 2000,
):
    """Download the GGUF model file if it is missing, then load it with llama.cpp.

    Args:
        directory: Directory that holds the model file; created if absent.
        model_name: File name of the model inside ``directory``.
        model_url: URL to fetch the model from when the file does not exist yet.
        n_ctx: Context window size passed to ``Llama``.

    Returns:
        The loaded ``llama_cpp.Llama`` instance.
    """
    final_model_path = os.path.join(directory, model_name)

    print("Downloading all files...")
    if not os.path.exists(final_model_path):
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Download into a side file and move it into place only on success.
        # Otherwise an interrupted download leaves a truncated file at the
        # final path, which the exists() check above would accept next start.
        tmp_path = final_model_path + ".incomplete"
        try:
            with open(tmp_path, "wb") as f:
                http_get(model_url, f)
            os.replace(tmp_path, final_model_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # NOTE(review): 0o777 makes the weights world-writable and executable.
    # Kept for parity with the original deployment; consider 0o644.
    os.chmod(final_model_path, 0o777)
    print("Files downloaded!")

    model = Llama(
        model_path=final_model_path,
        n_ctx=n_ctx,
        n_parts=1,
    )

    print("Model loaded!")
    return model
|
|
|
|
|
# Upper bound on the number of tokens generated for one reply (checked in bot()).
MAX_NEW_TOKENS: int = 1500
# Multilingual sentence encoder; embeds both the indexed fragments (build_index)
# and the user's question (retrieve).
EMBEDDER = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
# llama.cpp model. load_model() downloads the weights if the file is missing,
# so this line may perform a large network download at import time.
MODEL = load_model()
|
|
|
|
|
def get_uuid():
    """Return a fresh random (version 4) UUID as its canonical 36-character string."""
    identifier = uuid4()
    return f"{identifier}"
|
|
|
|
|
def load_single_document(file_path: str) -> Document:
    """Load ``file_path`` with the loader registered for its extension in LOADER_MAPPING.

    Some loaders return several documents for a single file (langchain's
    CSVLoader yields one per row). Those are merged into one Document, joined by
    blank lines, so that no content is silently dropped. Metadata is taken from
    the first document.

    Raises:
        ValueError: the extension is not in LOADER_MAPPING, or the loader
            returned no documents.
    """
    # splitext ignores dots in directory names; lower() accepts "report.PDF".
    ext = os.path.splitext(file_path)[1].lower()
    if ext not in LOADER_MAPPING:
        raise ValueError(f"Unsupported file extension {ext!r}: {file_path}")

    loader_class, loader_args = LOADER_MAPPING[ext]
    loaded = loader_class(file_path, **loader_args).load()
    if not loaded:
        raise ValueError(f"No content could be loaded from {file_path}")
    if len(loaded) == 1:
        return loaded[0]
    return Document(
        page_content="\n\n".join(doc.page_content for doc in loaded),
        metadata=dict(loaded[0].metadata),
    )
|
|
|
|
|
def get_message_tokens(model, role, content):
    """Tokenize one chat turn laid out as role, newline, content, newline, "</s>".

    The turn text is UTF-8 encoded and passed to ``model.tokenize`` with
    ``special=True``. This lets the literal "</s>" be parsed as a special token.
    """
    turn = "{}\n{}\n</s>".format(role, content)
    return model.tokenize(turn.encode("utf-8"), special=True)
|
|
|
|
|
def get_system_tokens(model):
    """Tokenize SYSTEM_PROMPT as a ``system`` turn for ``model``."""
    return get_message_tokens(model, role="system", content=SYSTEM_PROMPT)
|
|
|
|
|
def process_text(text):
    """Drop near-empty lines from a text fragment and reject fragments that are too short.

    Lines whose stripped length is 2 characters or less are removed. The
    remaining lines are re-joined with newlines and the result is stripped.
    Returns None when fewer than 10 characters remain, otherwise the cleaned text.
    """
    meaningful = (row for row in text.split("\n") if len(row.strip()) > 2)
    cleaned = "\n".join(meaningful).strip()
    return None if len(cleaned) < 10 else cleaned
|
|
|
|
|
def upload_files(files, file_paths):
    """Return the local paths of the files currently held by the upload widget.

    Args:
        files: Value of the ``gr.File`` component. Presumably either a list of
            uploaded temp-file objects exposing a ``.name`` path or a list of
            path strings (newer Gradio). ``None``/empty when the widget is
            cleared.
        file_paths: Previous value of the paths state. Unused; kept because the
            event passes it.

    Returns:
        List of file path strings; empty when no files are uploaded.
    """
    # Clearing the widget fires the change event with no files; iterating None
    # would raise TypeError.
    if not files:
        return []
    return [getattr(f, "name", f) for f in files]
|
|
|
|
|
def build_index(file_paths, db, chunk_size, chunk_overlap, file_warning):
    """Load, chunk, clean and embed the uploaded files into an in-memory index.

    Args:
        file_paths: Paths of the files to index.
        db: Previous index value. Unused; it is replaced.
        chunk_size: Target fragment size passed to RecursiveCharacterTextSplitter.
        chunk_overlap: Fragment overlap. Clamped to ``chunk_size`` because the
            splitter rejects a larger overlap.
        file_warning: Previous status text. Unused; it is replaced.

    Returns:
        ``(db, file_warning)``. ``db`` is ``{"docs": [str, ...], "embeddings": tensor}``,
        or None when nothing usable was found. ``file_warning`` is the status
        message to display.
    """
    chunk_size = int(chunk_size)
    chunk_overlap = min(int(chunk_overlap), chunk_size)

    documents = [load_single_document(path) for path in file_paths]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    documents = text_splitter.split_documents(documents)
    print("Documents after split:", len(documents))

    fixed_documents = []
    for doc in documents:
        doc.page_content = process_text(doc.page_content)
        # process_text returns None for fragments that are too short.
        if doc.page_content:
            fixed_documents.append(doc)
    print("Documents after processing:", len(fixed_documents))

    if not fixed_documents:
        # Nothing to index: the widget was cleared, or the files had no usable
        # text. Encoding an empty list and later running top-k over it fails,
        # so reset the index instead.
        return None, "Фрагменты ещё не загружены!"

    texts = [doc.page_content for doc in fixed_documents]
    embeddings = EMBEDDER.encode(texts, convert_to_tensor=True)
    db = {"docs": texts, "embeddings": embeddings}
    print("Embeddings calculated!")

    file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
    return db, file_warning
|
|
|
|
|
def retrieve(history, db, retrieved_docs, k_documents):
    """Return the indexed fragments most similar to the last user message.

    Args:
        history: Chat history as ``[[user_message, bot_message], ...]``; the
            last entry's user message is the query.
        db: Index built by build_index (``{"docs": [...], "embeddings": tensor}``)
            or None.
        retrieved_docs: Previous value. Unused; it is replaced.
        k_documents: Number of fragments to return. Capped at the index size.

    Returns:
        The selected fragments joined by blank lines, best match first. Returns
        "" when there is no index.
    """
    retrieved_docs = ""
    if not db or not db.get("docs"):
        return retrieved_docs

    last_user_message = history[-1][0]
    query_embedding = EMBEDDER.encode(last_user_message, convert_to_tensor=True)
    scores = cos_sim(query_embedding, db["embeddings"])[0]
    # torch.topk raises if k exceeds the number of scores; the slider allows up
    # to 10 fragments even when fewer are indexed.
    k = min(int(k_documents), len(db["docs"]))
    top_k_idx = torch.topk(scores, k=k)[1].tolist()
    top_k_documents = [db["docs"][idx] for idx in top_k_idx]
    return "\n\n".join(top_k_documents)
|
|
|
|
|
def user(message, history, system_prompt):
    """Append the user's message as a new turn with no reply yet.

    Returns a pair: an empty string that clears the input box, and a NEW
    history list. The passed-in ``history`` is not mutated. ``system_prompt`` is
    unused.
    """
    pending_turn = [message, None]
    return "", history + [pending_turn]
|
|
|
|
|
def _build_prompt_tokens(history, retrieved_docs):
    """Assemble the llama.cpp prompt token list for the dialogue in ``history``."""
    tokens = get_system_tokens(MODEL)[:]
    tokens.append(LINEBREAK_TOKEN)

    # Completed turns; a missing bot reply is skipped.
    for user_message, bot_message in history[:-1]:
        tokens.extend(get_message_tokens(model=MODEL, role="user", content=user_message))
        if bot_message:
            tokens.extend(get_message_tokens(model=MODEL, role="bot", content=bot_message))

    last_user_message = history[-1][0]
    if retrieved_docs:
        last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
    tokens.extend(get_message_tokens(model=MODEL, role="user", content=last_user_message))

    # Open the bot turn; the model continues generating from here.
    tokens.extend([MODEL.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
    return tokens


def bot(
    history,
    system_prompt,
    conversation_id,
    retrieved_docs,
    top_p,
    top_k,
    temp
):
    """Stream the model's reply to the last user turn into ``history``.

    Args:
        history: ``[[user_message, bot_message], ...]``. The last entry's bot
            slot is filled in progressively.
        system_prompt: Unused; the prompt is built from SYSTEM_PROMPT via
            get_system_tokens.
        conversation_id: Unused.
        retrieved_docs: Context text prepended to the last question when non-empty.
        top_p, top_k, temp: Sampling parameters forwarded to ``MODEL.generate``.

    Yields:
        ``history`` after each generated token.
    """
    if not history:
        return

    tokens = _build_prompt_tokens(history, retrieved_docs)

    # llama.cpp cannot evaluate past its context window, so bound generation by
    # the room left after the prompt as well as by MAX_NEW_TOKENS.
    max_new_tokens = MODEL.n_ctx() - len(tokens)
    if MAX_NEW_TOKENS is not None:
        max_new_tokens = min(max_new_tokens, MAX_NEW_TOKENS)

    generator = MODEL.generate(
        tokens,
        top_k=int(top_k),
        top_p=top_p,
        temp=temp
    )

    # A single character may be split across several byte-level tokens.
    # Decoding tokens one at a time (with "ignore") would drop such characters,
    # so accumulate raw bytes and decode the whole buffer each step. An
    # incomplete trailing sequence is hidden until its remaining bytes arrive.
    partial_bytes = b""
    for i, token in enumerate(generator):
        if token == MODEL.token_eos() or i >= max_new_tokens:
            break
        partial_bytes += MODEL.detokenize([token])
        history[-1][1] = partial_bytes.decode("utf-8", "ignore")
        yield history
|
|
|
|
|
# Gradio UI. Components are laid out in creation order inside the nested
# Row/Column/Tab context managers; the event wiring at the bottom connects them
# to the functions defined above.
# NOTE(review): `.style(...)` and `queue(concurrency_count=...)` are Gradio 3.x
# APIs; this layout presumably requires Gradio 3.
with gr.Blocks(
    theme=gr.themes.Soft()
) as demo:
    # Per-session state: the index built by build_index (None until files are
    # indexed) and a conversation id generated by get_uuid.
    db = gr.State(None)
    conversation_id = gr.State(get_uuid)
    favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
    gr.Markdown(
        f"""<h1><center>{favicon}Saiga 13B llama.cpp: retrieval QA</center></h1>
        """
    )

    # File upload with its status text (left) and chunking parameters for
    # build_index (right).
    with gr.Row():
        with gr.Column(scale=5):
            file_output = gr.File(file_count="multiple", label="Загрузка файлов")
            file_paths = gr.State([])
            file_warning = gr.Markdown(f"Фрагменты ещё не загружены!")

        with gr.Column(min_width=200, scale=3):
            with gr.Tab(label="Параметры нарезки"):
                chunk_size = gr.Slider(
                    minimum=50,
                    maximum=2000,
                    value=250,
                    step=50,
                    interactive=True,
                    label="Размер фрагментов",
                )
                chunk_overlap = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=30,
                    step=10,
                    interactive=True,
                    label="Пересечение"
                )

    # Number of fragments retrieve() selects as context for each question.
    with gr.Row():
        k_documents = gr.Slider(
            minimum=1,
            maximum=10,
            value=2,
            step=1,
            interactive=True,
            label="Кол-во фрагментов для контекста"
        )
    # Read-only view of the fragments retrieved for the last question.
    with gr.Row():
        retrieved_docs = gr.Textbox(
            lines=6,
            label="Извлеченные фрагменты",
            placeholder="Появятся после задавания вопросов",
            interactive=False
        )
    # Chat panel (left) and sampling parameters forwarded to bot() (right).
    with gr.Row():
        with gr.Column(scale=5):
            # Display only (interactive=False). bot() receives this value but
            # builds the prompt from SYSTEM_PROMPT via get_system_tokens.
            system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False)
            chatbot = gr.Chatbot(label="Диалог").style(height=400)
        with gr.Column(min_width=80, scale=1):
            with gr.Tab(label="Параметры генерации"):
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    interactive=True,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=30,
                    step=5,
                    interactive=True,
                    label="Top-k",
                )
                temp = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.1,
                    step=0.1,
                    interactive=True,
                    label="Temp"
                )

    # Message input and control buttons.
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Отправить сообщение",
                placeholder="Отправить сообщение",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Отправить")
                stop = gr.Button("Остановить")
                clear = gr.Button("Очистить")

    # Upload: collect file paths, then (only if that succeeded) rebuild the index
    # and update the status text.
    upload_event = file_output.change(
        fn=upload_files,
        inputs=[file_output, file_paths],
        outputs=[file_paths],
        queue=True,
    ).success(
        fn=build_index,
        inputs=[file_paths, db, chunk_size, chunk_overlap, file_warning],
        outputs=[db, file_warning],
        queue=True
    )

    # Pressing Enter: append the user turn (unqueued, so the UI updates
    # immediately), retrieve context fragments, then stream the bot reply.
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot, system_prompt],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=retrieve,
        inputs=[chatbot, db, retrieved_docs, k_documents],
        outputs=[retrieved_docs],
        queue=True,
    ).success(
        fn=bot,
        inputs=[
            chatbot,
            system_prompt,
            conversation_id,
            retrieved_docs,
            top_p,
            top_k,
            temp
        ],
        outputs=chatbot,
        queue=True,
    )

    # Same chain as above, triggered by the "send" button.
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot, system_prompt],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=retrieve,
        inputs=[chatbot, db, retrieved_docs, k_documents],
        outputs=[retrieved_docs],
        queue=True,
    ).success(
        fn=bot,
        inputs=[
            chatbot,
            system_prompt,
            conversation_id,
            retrieved_docs,
            top_p,
            top_k,
            temp
        ],
        outputs=chatbot,
        queue=True,
    )

    # "Stop" cancels whichever submit chain is running.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )

    # "Clear" resets the chat display by returning None into the chatbot.
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
# Enable the request queue (at most 128 waiting events, one processed at a time,
# presumably so the single shared MODEL is never used concurrently), then start
# the server with errors shown in the UI. queue() returns the Blocks instance
# itself, so the calls chain.
demo.queue(max_size=128, concurrency_count=1).launch(show_error=True)
|
|