lyraLLaMA / demo.py
from lyra_llama import lyraLLaMA

# Paths to the converted model weights and the tokenizer files.
model_path = "./models/lamma-13b-1-gpu-fp16.bin"
tokenizer_path = "./models/"
dtype = 'fp16'

# Prompt (Chinese): "List 3 different machine learning algorithms and explain where each is applicable."
prompt = "列出3个不同的机器学习算法,并说明它们的适用范围"
max_output_length = 512

# Load the model.
model = lyraLLaMA(model_path, tokenizer_path, dtype)

# Wrap the prompt in the conversation template the model expects.
prompt = '<human>:' + prompt.strip() + '\n<bot>:'

# Build a batch of identical prompts (batch size 1 here).
bs = 1
prompts = [prompt] * bs

# Run generation for the whole batch with deterministic decoding (do_sample=False).
output_texts = model.generate(
    prompts, output_length=max_output_length,
    top_k=30, top_p=0.85, temperature=1.0, repetition_penalty=1.0, do_sample=False)

print(output_texts)
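
# --- Optional: sampled generation (a minimal sketch, not part of the original demo) ---
# This assumes model.generate accepts the same keyword arguments as above and that
# do_sample=True switches from deterministic decoding to top-k/top-p sampling.
sampled_texts = model.generate(
    prompts, output_length=max_output_length,
    top_k=30, top_p=0.85, temperature=1.0, repetition_penalty=1.0, do_sample=True)
print(sampled_texts)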