# Interactive command-line chat with a Saiga (Russian LLaMA) model via llama-cpp-python.
import codecs

import fire
from llama_cpp import Llama

# Each chat message is rendered as "<role>\n<content>"; tokenize() prepends BOS
# and we append EOS, matching the Saiga conversation format.
MESSAGE_TEMPLATE = "{role}\n{content}"
# System prompt (Russian): "You are Saiga, a Russian-language automatic assistant.
# You talk to people and help them."
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."

# Hardcoded ids from the Saiga vocabulary: the "bot" role token and "\n".
BOT_TOKEN = 9225
LINEBREAK_TOKEN = 13


def get_message_tokens(model, role, content):
    # Tokenize a single chat message and terminate it with EOS.
    message_text = MESSAGE_TEMPLATE.format(role=role, content=content)
    message_tokens = model.tokenize(message_text.encode("utf-8"))
    message_tokens.append(model.token_eos())
    return message_tokens


def get_system_tokens(model):
    system_message = {
        "role": "system",
        "content": SYSTEM_PROMPT
    }
    return get_message_tokens(model, **system_message)


def interact(
    model_path,
    n_ctx=2000,
    top_k=30,
    top_p=0.9,
    temperature=0.2,
    repeat_penalty=1.1
):
    model = Llama(
        model_path=model_path,
        n_ctx=n_ctx,
        n_parts=1,
    )

    # Evaluate the system prompt once up front; later generate() calls reuse
    # the cached prefix instead of recomputing it.
    system_tokens = get_system_tokens(model)
    tokens = system_tokens
    model.eval(tokens)

    while True:
        user_message = input("User: ")
        # Append the user message plus the opening of the bot turn
        # (BOS, "bot", "\n") so generation continues as the assistant.
        message_tokens = get_message_tokens(model=model, role="user", content=user_message)
        role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
        tokens += message_tokens + role_tokens
        generator = model.generate(
            tokens,
            top_k=top_k,
            top_p=top_p,
            temp=temperature,
            repeat_penalty=repeat_penalty
        )
        # Stream the reply token by token until EOS. An incremental decoder
        # buffers multi-byte UTF-8 characters that get split across tokens
        # (common for Cyrillic text), which a plain bytes.decode() would
        # raise UnicodeDecodeError on.
        decoder = codecs.getincrementaldecoder("utf-8")()
        for token in generator:
            tokens.append(token)
            if token == model.token_eos():
                break
            token_str = decoder.decode(model.detokenize([token]))
            print(token_str, end="", flush=True)
        print()


if __name__ == "__main__":
    fire.Fire(interact)
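
# Example invocation (a sketch; the script name and model path below are
# illustrative, any llama.cpp-compatible Saiga checkpoint works). fire.Fire
# exposes interact()'s arguments both positionally and as --flags:
#
#   python interact.py ./ggml-saiga-q4_0.bin
#   python interact.py --model_path=./ggml-saiga-q4_0.bin --temperature=0.5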