---
license: llama3
datasets:
- yuyijiong/Long-Instruction-with-Paraphrasing
language:
- zh
- en
pipeline_tag: text-generation
---
# Llama3-8B-Chinese-Chat-32k
## Training

- Context length extended to 32k with the NTK-aware RoPE scaling method.
- Starting from shenzhi-wang/Llama3-8B-Chinese-Chat, fine-tuned with QLoRA for 1 epoch on the Long-Instruction-with-Paraphrasing dataset.
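The sketch below shows the general shape of this recipe using the Hugging Face transformers, peft, and bitsandbytes libraries. It is illustrative only: the scaling factor, LoRA rank, and target modules are placeholder values rather than the exact settings used for this checkpoint, and the `rope_scaling` dict uses the older `{"type": ..., "factor": ...}` format (newer transformers releases expect a `rope_type` key).

```python
# Illustrative sketch only -- not the actual training script for this checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

base_id = "shenzhi-wang/Llama3-8B-Chinese-Chat"

# 4-bit NF4 quantization so the 8B base model fits on a single GPU (the "Q" in QLoRA).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Dynamic NTK-aware RoPE scaling: a factor of 4 stretches the 8k base window to ~32k.
model = AutoModelForCausalLM.from_pretrained(
    base_id,
    quantization_config=bnb_config,
    rope_scaling={"type": "dynamic", "factor": 4.0},
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_id)

# LoRA adapters on the attention projections (placeholder hyperparameters).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# The adapters are then trained for 1 epoch on
# yuyijiong/Long-Instruction-with-Paraphrasing with a standard SFT loop.
```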
## Usage

Usage is the same as the original Llama3-8B-Chinese-Chat model.
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "yuyijiong/Llama3-8B-Chinese-Chat-32k"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto"
)

messages = [
    {"role": "user", "content": "写一首诗吧"},  # "Write a poem"
]

# Build the prompt with the model's chat template and move it to the model's device.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(
    input_ids,
    max_new_tokens=32768,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

# Strip the prompt tokens and decode only the newly generated response.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```
## Long-Context Performance

Compared with the original model, this version has stronger long-context capability.
### LongBench (en)

| model | hotpotqa | multifieldqa_en | passage_retrieval_en | qmsum | trec |
|---|---|---|---|---|---|
| llama3-8b-chinese-chat | 45.88 | 50.56 | 68.00 | 22.52 | 73.00 |
| llama3-8b-chinese-chat-32k | 47.64 | 49.98 | 100.00 | 25.13 | 75.00 |
### LongBench (zh)

| model | dureader | multifieldqa_zh | passage_retrieval_zh | vcsum | lsht |
|---|---|---|---|---|---|
| llama3-8b-chinese-chat | 29.08 | 58.40 | 93.50 | 14.61 | 28.25 |
| llama3-8b-chinese-chat-32k | 32.31 | 58.66 | 82.50 | 16.15 | 38.50 |