# Source: IlyaGusev's Hugging Face Space, commit 8e5a242 ("Initial commit"), 7.62 kB.
import os
import uuid

import gradio as gr
from huggingface_hub import snapshot_download
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEmailLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from chromadb.config import Settings
from llama_cpp import Llama
# Default system prompt (English gloss: "You are Saiga, a Russian-speaking
# automatic assistant. You talk with people and help them.").
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."

# Token ids spliced into the prompt to mark who is speaking in each message,
# plus the id used as the line break after the role marker.
SYSTEM_TOKEN, USER_TOKEN, BOT_TOKEN = 1788, 1404, 9225
LINEBREAK_TOKEN = 13

# Chat role name -> role marker token id.
ROLE_TOKENS = dict(user=USER_TOKEN, bot=BOT_TOKEN, system=SYSTEM_TOKEN)
# File extension -> (loader class, extra constructor kwargs). Each loader is
# constructed as loader_class(file_path, **kwargs) by load_single_document().
_DEFAULT_LOADERS = (
    (".csv", CSVLoader),
    (".doc", UnstructuredWordDocumentLoader),
    (".docx", UnstructuredWordDocumentLoader),
    (".enex", EverNoteLoader),
    (".epub", UnstructuredEPubLoader),
    (".html", UnstructuredHTMLLoader),
    (".md", UnstructuredMarkdownLoader),
    (".odt", UnstructuredODTLoader),
    (".pdf", PDFMinerLoader),
    (".ppt", UnstructuredPowerPointLoader),
    (".pptx", UnstructuredPowerPointLoader),
)
# A fresh, empty kwargs dict per extension.
LOADER_MAPPING = {extension: (loader_cls, {}) for extension, loader_cls in _DEFAULT_LOADERS}
# Plain text files are opened with an explicit UTF-8 encoding.
LOADER_MAPPING[".txt"] = (TextLoader, {"encoding": "utf8"})
# Saiga 7B weights for llama.cpp (q4_1 quantization, judging by the file name)
# are fetched from the Hub into the working directory, then loaded from there.
MODEL_NAME = "ggml-model-q4_1.bin"
snapshot_download(
    repo_id="IlyaGusev/saiga_7b_lora_llamacpp",
    local_dir=".",
    allow_patterns=MODEL_NAME,
)
model = Llama(model_path=MODEL_NAME, n_ctx=2000, n_parts=1)

# Sampling settings read by bot() when streaming a reply.
max_new_tokens = 1500
top_k, top_p = 30, 0.9
temp = 0.1
repeat_penalty = 1.15

# Target chunk length passed to the text splitter in build_index().
chunk_size = 300

# Multilingual sentence embedder used to index and query uploaded documents.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
def load_single_document(file_path: str) -> Document:
    """Load one uploaded file into a single LangChain ``Document``.

    The loader is chosen from ``LOADER_MAPPING`` by the file's extension,
    matched case-insensitively (``report.PDF`` works like ``report.pdf``).
    Some loaders return several documents for one file; ``CSVLoader``, for
    example, returns one per row. Those are merged into one document so no
    content is dropped.

    Args:
        file_path: Path of the file on disk.

    Returns:
        The loader's document when it produced exactly one. Otherwise, a new
        ``Document`` whose text is all page contents joined by blank lines and
        whose metadata records ``file_path`` as the source.

    Raises:
        ValueError: If the extension is not supported or the loader returned
            no documents.
    """
    ext = os.path.splitext(file_path)[1].lower()
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if ext not in LOADER_MAPPING:
        raise ValueError(f"Unsupported file extension {ext!r} for {file_path!r}")
    loader_class, loader_args = LOADER_MAPPING[ext]
    documents = loader_class(file_path, **loader_args).load()
    if not documents:
        raise ValueError(f"No content could be loaded from {file_path!r}")
    if len(documents) == 1:
        return documents[0]
    # Taking only documents[0] here used to discard every row but the first.
    return Document(
        page_content="\n\n".join(doc.page_content for doc in documents),
        metadata={"source": file_path},
    )
def get_message_tokens(model, role, content):
    """Encode one chat message in Saiga's prompt layout.

    *content* is UTF-8 encoded and tokenized. The role marker
    ``ROLE_TOKENS[role]`` and ``LINEBREAK_TOKEN`` are spliced in right after
    the first token, and the model's EOS token is appended. The first token is
    presumably the BOS that llama-cpp-python's ``tokenize`` prepends by
    default. NOTE(review): confirm against the installed version.

    Args:
        model: llama.cpp model providing ``tokenize`` and ``token_eos``.
        role: Key of ``ROLE_TOKENS`` (``"user"``, ``"bot"`` or ``"system"``);
            any other value raises ``KeyError``.
        content: Message text.

    Returns:
        A new list of token ids.
    """
    encoded = model.tokenize(content.encode("utf-8"))
    encoded[1:1] = [ROLE_TOKENS[role], LINEBREAK_TOKEN]
    encoded += [model.token_eos()]
    return encoded
def get_system_tokens(model, system_prompt=None):
    """Return the encoded system message that opens every prompt.

    Args:
        model: llama.cpp model used for tokenization.
        system_prompt: Text of the system message. ``None`` or an empty string
            falls back to ``SYSTEM_PROMPT``, so ``get_system_tokens(model)``
            behaves exactly as before.

    Returns:
        A new list of token ids, as produced by ``get_message_tokens``.
    """
    content = system_prompt or SYSTEM_PROMPT
    return get_message_tokens(model, role="system", content=content)
def upload_files(files, file_paths):
    """Convert the ``gr.File`` component value into a list of file paths.

    Args:
        files: Uploaded file objects, each exposing ``.name`` (presumably the
            temporary path Gradio saved it to), or ``None`` when the user
            removed every file.
        file_paths: Previous value of the paths state. It is not read; the
            returned list replaces it. The argument is kept to match the
            event's inputs.

    Returns:
        The list of file paths, empty when no files are selected.
    """
    if not files:
        # The change event also fires when the selection is cleared, with
        # files=None; iterating it used to raise TypeError.
        return []
    return [uploaded.name for uploaded in files]
def build_index(file_paths, db):
documents = [load_single_document(path) for path in file_paths]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=20)
texts = text_splitter.split_documents(documents)
def fix_lines(text):
lines = text.split("\n")
lines = [line for line in lines if len(line.strip()) > 2]
return "\n".join(lines)
fixed_texts = []
for text in texts:
text.page_content = fix_lines(text.page_content)
if len(text.page_content) < 10:
continue
fixed_texts.append(text)
db = Chroma.from_documents(
fixed_texts,
embeddings,
client_settings=Settings(
anonymized_telemetry=False
)
)
return db
def user(message, history, system_prompt):
    """Record the submitted message as a new, not-yet-answered chat turn.

    Args:
        message: Text from the input box.
        history: Current chat turns as ``[user, bot]`` pairs; not mutated.
        system_prompt: Not read; accepted to match the event's inputs.

    Returns:
        ``("", new_history)``. The empty string clears the input box, and
        ``new_history`` is a new list: *history* followed by ``[message, None]``.
    """
    pending_turn = [message, None]
    extended = history + [pending_turn]
    return "", extended
def bot(history, system_prompt, conversation_id, db):
    """Stream the assistant's reply to the last user message in *history*.

    The prompt is built from:
    - the system message (taken from the UI);
    - every previous turn;
    - the newest user message, optionally prefixed with context retrieved
      from *db*;
    - an opened bot turn.

    The model is then sampled, and *history* is yielded after each token with
    the partial reply written into the last turn.

    Args:
        history: Chat turns as ``[user_message, bot_message]`` pairs. The last
            pair's bot slot is overwritten in place.
        system_prompt: System prompt from the UI. An empty value falls back to
            ``SYSTEM_PROMPT``. It used to be ignored in favor of the constant.
        conversation_id: Session identifier; not used here.
        db: Vector store from ``build_index``, or ``None`` to answer without
            retrieval.

    Yields:
        The same *history* list, updated with the growing reply.
    """
    if not history:
        return

    tokens = get_message_tokens(model=model, role="system", content=system_prompt or SYSTEM_PROMPT)
    tokens.append(LINEBREAK_TOKEN)
    for user_message, bot_message in history[:-1]:
        tokens.extend(get_message_tokens(model=model, role="user", content=user_message))
        if bot_message:
            tokens.extend(get_message_tokens(model=model, role="bot", content=bot_message))

    last_user_message = history[-1][0]
    if db:
        retriever = db.as_retriever(search_kwargs={"k": 2})
        docs = retriever.get_relevant_documents(last_user_message)
        context = "\n\n".join(doc.page_content for doc in docs)
        last_user_message = f"Контекст: {context}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
    tokens.extend(get_message_tokens(model=model, role="user", content=last_user_message))
    # Open the bot turn (BOS, role marker, line break) and leave it unterminated
    # so the model continues it.
    tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])

    # Prompt plus reply must fit in the context window. max_new_tokens (1500)
    # alone could overrun n_ctx once the prompt is added.
    # NOTE(review): a prompt that by itself exceeds n_ctx still fails inside
    # llama.cpp; trimming old turns would be needed for very long chats.
    token_limit = model.n_ctx() - len(tokens)
    if max_new_tokens is not None:
        token_limit = min(token_limit, max_new_tokens)

    eos_token = model.token_eos()
    generator = model.generate(
        tokens,
        top_k=top_k,
        top_p=top_p,
        temp=temp,
        repeat_penalty=repeat_penalty,
    )
    completion_tokens = []
    for i, token in enumerate(generator):
        completion_tokens.append(token)
        if token == eos_token or i >= token_limit:
            break
        # Decode the whole completion each step. Bytes of a multi-byte UTF-8
        # character that is still incomplete are dropped ("ignore") until the
        # character's remaining bytes arrive.
        history[-1][1] = model.detokenize(completion_tokens).decode("utf-8", "ignore")
        yield history
def get_uuid():
    """Return a new random conversation id as a string.

    ``gr.State`` receives this callable rather than a value, so Gradio calls
    it to create each session's initial state. The function used to be
    referenced without being defined, which raised NameError at startup.
    """
    return str(uuid.uuid4())


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session state: the vector store (None until files are indexed) and
    # a conversation id.
    db = gr.State(None)
    conversation_id = gr.State(get_uuid)
    favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
    gr.Markdown(f"<h1><center>{favicon}Saiga 7B Retrieval QA Llama.cpp</center></h1>\n")
    system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT)
    file_output = gr.File(file_count="multiple")
    file_paths = gr.State([])
    chatbot = gr.Chatbot().style(height=400)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Отправить сообщение",
                placeholder="Отправить сообщение",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Отправить")
                stop = gr.Button("Остановить")
                clear = gr.Button("Очистить")

    # Any change to the uploaded files (including removing them) rebuilds the
    # retrieval index.
    upload_event = file_output.change(
        fn=upload_files,
        inputs=[file_output, file_paths],
        outputs=[file_paths],
        queue=False,
    ).then(
        fn=build_index,
        inputs=[file_paths, db],
        outputs=[db],
        queue=True,
    )

    # Enter in the textbox and the send button run the same chain. First the
    # user turn is appended (unqueued), then the reply is streamed (queued).
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot, system_prompt],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, system_prompt, conversation_id, db],
        outputs=chatbot,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot, system_prompt],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, system_prompt, conversation_id, db],
        outputs=chatbot,
        queue=True,
    )

    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    # Clearing also cancels an in-flight reply. Otherwise the streaming
    # generator keeps writing its history back into the chat it just emptied.
    clear.click(
        lambda: None,
        None,
        chatbot,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )

demo.queue(max_size=128, concurrency_count=1)
demo.launch()