import gradio as gr from huggingface_hub import snapshot_download from langchain.document_loaders import ( CSVLoader, EverNoteLoader, PDFMinerLoader, TextLoader, UnstructuredEmailLoader, UnstructuredEPubLoader, UnstructuredHTMLLoader, UnstructuredMarkdownLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, UnstructuredWordDocumentLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.vectorstores import Chroma from langchain.embeddings import HuggingFaceEmbeddings from langchain.docstore.document import Document from chromadb.config import Settings from llama_cpp import Llama SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им." SYSTEM_TOKEN = 1788 USER_TOKEN = 1404 BOT_TOKEN = 9225 LINEBREAK_TOKEN = 13 ROLE_TOKENS = { "user": USER_TOKEN, "bot": BOT_TOKEN, "system": SYSTEM_TOKEN } LOADER_MAPPING = { ".csv": (CSVLoader, {}), ".doc": (UnstructuredWordDocumentLoader, {}), ".docx": (UnstructuredWordDocumentLoader, {}), ".enex": (EverNoteLoader, {}), ".epub": (UnstructuredEPubLoader, {}), ".html": (UnstructuredHTMLLoader, {}), ".md": (UnstructuredMarkdownLoader, {}), ".odt": (UnstructuredODTLoader, {}), ".pdf": (PDFMinerLoader, {}), ".ppt": (UnstructuredPowerPointLoader, {}), ".pptx": (UnstructuredPowerPointLoader, {}), ".txt": (TextLoader, {"encoding": "utf8"}), } MODEL_NAME = "ggml-model-q4_1.bin" snapshot_download( repo_id="IlyaGusev/saiga_7b_lora_llamacpp", local_dir=".", allow_patterns=MODEL_NAME ) model = Llama( model_path=MODEL_NAME, n_ctx=2000, n_parts=1, ) max_new_tokens = 1500 top_k = 30 top_p = 0.9 temp = 0.1 repeat_penalty = 1.15 chunk_size = 300 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") def load_single_document(file_path: str) -> Document: ext = "." + file_path.rsplit(".", 1)[-1] assert ext in LOADER_MAPPING loader_class, loader_args = LOADER_MAPPING[ext] loader = loader_class(file_path, **loader_args) return loader.load()[0] def get_message_tokens(model, role, content): message_tokens = model.tokenize(content.encode("utf-8")) message_tokens.insert(1, ROLE_TOKENS[role]) message_tokens.insert(2, LINEBREAK_TOKEN) message_tokens.append(model.token_eos()) return message_tokens def get_system_tokens(model): system_message = {"role": "system", "content": SYSTEM_PROMPT} return get_message_tokens(model, **system_message) def upload_files(files, file_paths): file_paths = [f.name for f in files] return file_paths def build_index(file_paths, db): documents = [load_single_document(path) for path in file_paths] text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=20) texts = text_splitter.split_documents(documents) def fix_lines(text): lines = text.split("\n") lines = [line for line in lines if len(line.strip()) > 2] return "\n".join(lines) fixed_texts = [] for text in texts: text.page_content = fix_lines(text.page_content) if len(text.page_content) < 10: continue fixed_texts.append(text) db = Chroma.from_documents( fixed_texts, embeddings, client_settings=Settings( anonymized_telemetry=False ) ) return db def user(message, history, system_prompt): new_history = history + [[message, None]] return "", new_history def bot(history, system_prompt, conversation_id, db): if not history: return tokens = get_system_tokens(model)[:] tokens.append(LINEBREAK_TOKEN) for user_message, bot_message in history[:-1]: message_tokens = get_message_tokens(model=model, role="user", content=user_message) tokens.extend(message_tokens) if bot_message: message_tokens = get_message_tokens(model=model, role="bot", content=bot_message) tokens.extend(message_tokens) last_user_message = history[-1][0] if db: retriever = db.as_retriever(search_kwargs={"k": 2}) docs = retriever.get_relevant_documents(last_user_message) context = "\n\n".join([doc.page_content for doc in docs]) last_user_message = f"Контекст: {context}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}" message_tokens = get_message_tokens(model=model, role="user", content=last_user_message) tokens.extend(message_tokens) role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN] tokens.extend(role_tokens) generator = model.generate( tokens, top_k=top_k, top_p=top_p, temp=temp, repeat_penalty=repeat_penalty ) completion_tokens = [] partial_text = "" for i, token in enumerate(generator): completion_tokens.append(token) if token == model.token_eos() or (max_new_tokens is not None and i >= max_new_tokens): break partial_text = model.detokenize(completion_tokens).decode("utf-8", "ignore") history[-1][1] = partial_text yield history with gr.Blocks( theme=gr.themes.Soft() ) as demo: db = gr.State(None) conversation_id = gr.State(get_uuid) favicon = '' gr.Markdown( f"""