File size: 1,934 Bytes
46085bb db1b241 46085bb db1b241 46085bb db1b241 46085bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import fire
from llama_cpp import Llama
# Default system prompt (Russian). English: "You are Saiga, a Russian-speaking
# automatic assistant. You talk with people and help them."
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."

# Role marker token ids. get_message_tokens places one of these right after
# the first token of every message. Presumably these are the vocabulary ids of
# the role words for the Saiga/LLaMA tokenizer. Verify against the tokenizer
# before using a different model.
SYSTEM_TOKEN = 1788
USER_TOKEN = 1404
BOT_TOKEN = 9225

# Token inserted between the role marker and the message text.
# Presumably the newline token of the LLaMA vocabulary (TODO confirm).
LINEBREAK_TOKEN = 13

# Role name -> role marker token id, as looked up by get_message_tokens.
ROLE_TOKENS = dict(
    user=USER_TOKEN,
    bot=BOT_TOKEN,
    system=SYSTEM_TOKEN,
)
def get_message_tokens(model, role, content):
    """Tokenize one chat message into the Saiga prompt layout.

    The returned layout is ``[first, <role token>, LINEBREAK_TOKEN, *rest, <eos>]``.
    Here ``first, *rest`` is the tokenization of ``content``. If that
    tokenization is empty, the result is ``[<role token>, LINEBREAK_TOKEN, <eos>]``.
    ``first`` is presumably the BOS token that ``model.tokenize`` prepends by
    default. Verify this for the llama-cpp-python version in use.

    Args:
        model: Object providing ``tokenize(bytes)`` and ``token_eos()``
            (a ``llama_cpp.Llama`` in this script).
        role: Key of ``ROLE_TOKENS``: "user", "bot" or "system".
        content: Message text; encoded as UTF-8 before tokenization.

    Returns:
        A new list of token ids.

    Raises:
        KeyError: If ``role`` is not a key of ``ROLE_TOKENS``.
    """
    encoded = model.tokenize(content.encode("utf-8"))
    head, tail = encoded[:1], encoded[1:]
    return [*head, ROLE_TOKENS[role], LINEBREAK_TOKEN, *tail, model.token_eos()]
def get_system_tokens(model):
    """Return the token ids of the system message built from ``SYSTEM_PROMPT``.

    Args:
        model: The tokenizer-bearing model passed on to ``get_message_tokens``.

    Returns:
        The list produced by ``get_message_tokens`` for role ``"system"``.
    """
    return get_message_tokens(model=model, role="system", content=SYSTEM_PROMPT)
def interact(
    model_path,
    n_ctx=2000,
    top_k=30,
    top_p=0.9,
    temperature=0.2,
    repeat_penalty=1.1,
):
    """Run an interactive console chat with a Saiga model loaded via llama.cpp.

    The conversation is kept as one flat token list. It starts with the system
    message and then holds each user message and bot reply in turn. For each
    user turn, the message plus the header of a bot message is appended. The
    model then continues from there until it emits the end-of-sequence token.
    The reply is streamed to stdout as it is generated.

    Args:
        model_path: Path of the model file, passed to ``Llama``.
        n_ctx: Context window size, passed to ``Llama``.
        top_k: Sampling setting forwarded to ``model.generate``.
        top_p: Sampling setting forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate`` as ``temp``.
        repeat_penalty: Forwarded to ``model.generate``.

    The session ends when standard input is closed (EOF) or the user presses
    Ctrl-C at the prompt.
    """
    import codecs

    model = Llama(
        model_path=model_path,
        n_ctx=n_ctx,
        n_parts=1,
    )

    # Copy, so that appending conversation turns below does not mutate the
    # list returned by get_system_tokens.
    tokens = list(get_system_tokens(model))
    model.eval(tokens)
    eos = model.token_eos()

    # NOTE(review): `tokens` grows with every turn and is never trimmed. Once
    # the history no longer fits in n_ctx, llama.cpp evaluation is expected to
    # fail. TODO: drop old turns when approaching the limit.
    while True:
        try:
            user_message = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Input closed or interrupted: end the session without a traceback.
            print()
            break

        message_tokens = get_message_tokens(model=model, role="user", content=user_message)
        # Header of the bot message that the model is asked to complete.
        role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
        tokens += message_tokens + role_tokens

        generator = model.generate(
            tokens,
            top_k=top_k,
            top_p=top_p,
            temp=temperature,
            repeat_penalty=repeat_penalty,
        )

        # One token may carry only part of a multi-byte UTF-8 character (e.g.
        # byte-fallback pieces of Cyrillic text or emoji). Decoding each token
        # on its own would then raise UnicodeDecodeError. The incremental
        # decoder buffers incomplete sequences until the remaining bytes arrive.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        for token in generator:
            # Generated tokens (including EOS) become part of the history.
            tokens.append(token)
            if token == eos:
                break
            print(decoder.decode(model.detokenize([token])), end="", flush=True)
        # Flush any bytes still buffered and finish the line.
        print(decoder.decode(b"", final=True))
if __name__ == "__main__":
    # python-fire builds the command-line interface from interact's signature.
    fire.Fire(component=interact)
|