license: mit
pipeline_tag: text-generation

Llama-3-8B-Instruct-80K-QLoRA

[Data&Code]

We extend the context length of Llama-3-8B-Instruct from 8K to 80K tokens using QLoRA and 3.5K long-context training samples synthesized with GPT-4. Training is highly efficient: the entire run takes 8 hours on a single 8xA800 (80GB) machine. Despite this small budget, the resulting model performs strongly on a series of long-context downstream benchmarks.
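QLoRA fine-tunes small low-rank adapters while the base model's weights stay frozen (and, in QLoRA, quantized to 4 bits during training). For a frozen weight matrix W of shape d_out x d_in, LoRA learns B (d_out x r) and A (r x d_in) and computes with W + (alpha / r)·BA. Folding that product into W is what merging the adapter (`merge_and_unload()` in the usage code) amounts to. The pure-Python sketch below uses toy shapes and a hypothetical rank and alpha, not this model's actual hyperparameters, and leaves quantization out; it shows the merge step and why the number of trainable parameters stays small.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain matrix product of x (n x k) and y (k x m)."""
    return [[sum(x[i][t] * y[t][j] for t in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def merge_lora(w: Matrix, a: Matrix, b: Matrix, alpha: float) -> Matrix:
    """Return W + (alpha / r) * B @ A, i.e. the adapter folded into the base weight."""
    r = len(a)  # LoRA rank = number of rows of A
    scale = alpha / r
    delta = matmul(b, a)  # (d_out x r) @ (r x d_in) -> d_out x d_in
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))] for i in range(len(w))]


def trainable_params(d_out: int, d_in: int, r: int) -> tuple:
    """Parameters updated by full fine-tuning vs. by a rank-r LoRA adapter."""
    return d_out * d_in, r * (d_out + d_in)


# Toy example (hypothetical values): 2x3 weight, rank-1 adapter, alpha = 2.
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]
A = [[1.0, 2.0, 3.0]]  # r x d_in = 1 x 3
B = [[0.5], [0.0]]     # d_out x r = 2 x 1
merged = merge_lora(W, A, B, alpha=2.0)
print(merged)  # [[2.0, 2.0, 3.0], [0.0, 1.0, 0.0]]

# One 4096 x 4096 projection (Llama-3-8B's hidden size) with a hypothetical rank of 32:
print(trainable_params(4096, 4096, 32))  # (16777216, 262144)
```

Only A and B receive gradients, so the adapter above trains about 1.6% of the parameters of that projection, which is a large part of why a single 8-GPU machine suffices.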

Evaluation

All of the following evaluation results can be reproduced by following the instructions here.

Needle in a Haystack

We evaluate the model on the Needle-in-a-Haystack task using the official setting.
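In a needle-in-a-haystack test, a short fact (the "needle") is placed at a chosen relative depth inside long filler text (the "haystack"), and the model is asked to retrieve it; accuracy is reported over a grid of context lengths and depths. The sketch below shows only the prompt-construction step and is a simplification: `insert_needle` and `build_grid` are hypothetical helpers, and splitting on whitespace stands in for the tokenizer-based placement a real harness would use.

```python
from typing import List, Tuple


def insert_needle(haystack: str, needle: str, depth: float) -> str:
    """Insert `needle` into `haystack` at a relative depth in [0, 1].

    depth = 0.0 puts the needle at the very start, 1.0 at the very end.
    Whitespace-separated words stand in for real tokens here.
    """
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be between 0 and 1")
    words = haystack.split()
    pos = round(depth * len(words))
    return " ".join(words[:pos] + needle.split() + words[pos:])


def build_grid(context_lengths: List[int], depths: List[float]) -> List[Tuple[int, float]]:
    """All (context length, depth) pairs to evaluate, e.g. for a heatmap."""
    return [(n, d) for n in context_lengths for d in depths]


haystack = " ".join(f"w{i}" for i in range(10))
prompt = insert_needle(haystack, "The secret number is 42.", depth=0.5)
print(prompt)  # w0 w1 w2 w3 w4 The secret number is 42. w5 w6 w7 w8 w9
print(len(build_grid([1000, 2000], [0.0, 0.5, 1.0])))  # 6
```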

LongBench

We evaluate the model on LongBench using a 32K context length and the official prompt template. For meta-llama/Meta-Llama-3-8B-Instruct, we use an 8K context length, its native maximum.

| Model | Single-Doc QA | Multi-Doc QA | Summarization | Few-Shot Learning |
|---|---|---|---|---|
| meta-llama/Meta-Llama-3-8B-Instruct | 37.33 | 36.04 | 26.83 | 69.56 |
| gradientai/Llama-3-8B-Instruct-262k | 37.29 | 31.20 | 26.18 | 67.25 |
| Llama-3-8B-Instruct-80K-QLoRA | 43.57 | 43.07 | 28.93 | 69.15 |
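When an input exceeds the context budget (32K tokens for this model, 8K for the base model), it has to be truncated. As far as we know, LongBench's reference prediction script truncates from the middle, keeping the beginning and end of the prompt, where instructions and questions usually sit. The sketch below shows that idea on a list of token IDs; `truncate_middle` is a hypothetical name, and the official script's exact split may differ slightly.

```python
from typing import List


def truncate_middle(token_ids: List[int], max_length: int) -> List[int]:
    """Keep the first and last max_length // 2 tokens when the input is too long."""
    if len(token_ids) <= max_length:
        return token_ids
    half = max_length // 2
    # Slice with an explicit start index so half == 0 yields an empty tail
    # (token_ids[-0:] would return the whole list).
    return token_ids[:half] + token_ids[len(token_ids) - half:]


ids = list(range(10))
print(truncate_middle(ids, 4))          # [0, 1, 8, 9]
print(truncate_middle(ids, 32) == ids)  # True
```

Dropping the middle rather than the tail is a deliberate choice: cutting the end would usually remove the question the model is supposed to answer.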

InfiniteBench

We evaluate the model on InfiniteBench using an 80K context length and the official prompt template. The GPT-4 results are copied from the InfiniteBench paper. For meta-llama/Meta-Llama-3-8B-Instruct, we use an 8K context length.

Topic Retrieval

We evaluate the model on the Topic Retrieval task with 5, 10, 15, 20, 25, 30, 40, 50, 60, and 70 topics.
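The Topic Retrieval task (from LMSYS's LongChat evaluation) presents a long conversation covering several topics and asks the model to name one of them, typically the first; adding topics makes the input longer. The sketch below builds such a prompt from hypothetical topic snippets. The turn format, the question wording, and the helper names are assumptions for illustration, not the benchmark's exact template.

```python
from typing import List, Tuple


def build_topic_prompt(topics: List[Tuple[str, str]]) -> Tuple[str, str]:
    """Return (prompt, expected_answer) for a multi-topic conversation.

    `topics` is a list of (topic_name, discussion_text) pairs in conversation order.
    The expected answer is the first topic discussed.
    """
    if not topics:
        raise ValueError("need at least one topic")
    turns = [f"USER: Let's talk about {name}.\nASSISTANT: {text}" for name, text in topics]
    question = "What is the first topic we discussed? Only give the topic name."
    prompt = "\n\n".join(turns) + "\n\n" + question
    return prompt, topics[0][0]


def is_correct(prediction: str, answer: str) -> bool:
    """Lenient match: the reference topic appears somewhere in the model output."""
    return answer.lower() in prediction.lower()


topics = [("the history of tea", "Tea drinking spread along trade routes."),
          ("volcanoes", "Volcanoes form where magma reaches the surface."),
          ("chess openings", "The Sicilian Defence is a popular reply to 1.e4.")]
prompt, answer = build_topic_prompt(topics)
print(answer)                                                         # the history of tea
print(is_correct("The first topic was the history of tea.", answer))  # True
```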

MMLU

We evaluate the model's zero-shot performance on the MMLU benchmark to gauge how well it retains its short-context capability.

| Model | STEM | Social Sciences | Humanities | Others | Avg |
|---|---|---|---|---|---|
| meta-llama/Meta-Llama-3-8B-Instruct | 53.87 | 75.66 | 69.44 | 69.75 | 65.91 |
| gradientai/Llama-3-8B-Instruct-262k | 52.10 | 73.26 | 67.15 | 69.80 | 64.34 |
| Llama-3-8B-Instruct-80K-QLoRA | 53.10 | 73.24 | 67.32 | 68.79 | 64.44 |
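Zero-shot means the prompt contains no solved examples. One common way to score MMLU in this setting is to compare the model's probability of each answer letter after the prompt and pick the highest. The sketch below assumes that protocol and a hypothetical prompt wording; `score` is a placeholder for a real model call, and the evaluation code behind the table above may differ.

```python
from typing import Callable, Dict, List

CHOICES = "ABCD"


def format_mmlu_prompt(subject: str, question: str, options: List[str]) -> str:
    """Zero-shot prompt: no solved examples, just the question and lettered options."""
    lines = [f"The following is a multiple choice question about {subject}.", "", question]
    lines += [f"{letter}. {option}" for letter, option in zip(CHOICES, options)]
    lines.append("Answer:")
    return "\n".join(lines)


def predict(prompt: str, score: Callable[[str, str], float]) -> str:
    """Pick the letter the model scores highest.

    `score(prompt, letter)` stands in for the model's log-probability of
    generating `letter` right after `prompt`.
    """
    return max(CHOICES, key=lambda letter: score(prompt, letter))


def accuracy(predictions: List[str], labels: List[str]) -> float:
    """Fraction of predictions that match the reference letters."""
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


# Hypothetical log-probabilities in place of a real model call.
fake_logprobs: Dict[str, float] = {"A": -2.3, "B": -0.4, "C": -1.9, "D": -3.0}
prompt = format_mmlu_prompt("astronomy", "Which planet is the largest?",
                            ["Mars", "Jupiter", "Venus", "Earth"])
print(predict(prompt, lambda p, letter: fake_logprobs[letter]))  # B
print(accuracy(["B", "C"], ["B", "A"]))                          # 0.5
```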

Environment

torch==2.2.2
flash_attn==2.5.6
transformers==4.39.3
peft==0.10.0
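Results may depend on these exact versions, so a quick check that the installed packages match the pins can save debugging time. The helper below is a hypothetical convenience script, not part of the repository. It uses only the standard library's `importlib.metadata` and ignores local build suffixes such as `+cu121` on PyTorch wheels.

```python
from importlib import metadata
from typing import Callable, Dict

PINNED = {"torch": "2.2.2", "flash_attn": "2.5.6", "transformers": "4.39.3", "peft": "0.10.0"}


def find_mismatches(pinned: Dict[str, str],
                    lookup: Callable[[str], str] = metadata.version) -> Dict[str, str]:
    """Map each package whose installed version differs from its pin to what was found."""
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = "not installed"
        # Compare without a local version suffix, e.g. "2.2.2+cu121" -> "2.2.2".
        if found.split("+", 1)[0] != wanted:
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    print(find_mismatches(PINNED) or "environment matches the pinned versions")
```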

Usage

import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
peft_id = "namespace-Pt/Llama-3-8B-Instruct-80K-QLoRA"

torch_dtype = torch.bfloat16
# place the model on GPU
device_map = {"": "cuda"}

tokenizer = AutoTokenizer.from_pretrained(model_id)

base_model = AutoModelForCausalLM.from_pretrained(
  model_id,
  torch_dtype=torch_dtype,
  device_map=device_map,
  attn_implementation="flash_attention_2",

  # NOTE: enlarge the RoPE base and the position limit to match the 80K (81920-token) setup
  rope_theta=200e6,
  max_position_embeddings=81920,
)

model = PeftModel.from_pretrained(
  base_model,
  peft_id,
  torch_dtype=torch_dtype,
  device_map=device_map,
)
# NOTE: merge the LoRA weights into the base model so inference runs without adapter overhead
model = model.merge_and_unload().eval()

# Llama-3-Instruct ends each assistant turn with <|eot_id|>, so stop on it as well as on the EOS token
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

with torch.no_grad():
  # short context
  messages = [{"role": "user", "content": "Tell me about yourself."}]
  inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
  ).to("cuda")
  # keep only the newly generated tokens (drop the prompt)
  outputs = model.generate(**inputs, max_new_tokens=50, eos_token_id=terminators)[:, inputs["input_ids"].shape[1]:]
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
  print(f"Output:       {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

  # long context: the JSON file provides a "context" prompt and a reference "answer"
  with open("data/narrativeqa.json", encoding="utf-8") as f:
    example = json.load(f)
  messages = [{"role": "user", "content": example["context"]}]
  inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
  ).to("cuda")
  # greedy decoding
  outputs = model.generate(
    **inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20, eos_token_id=terminators
  )[:, inputs["input_ids"].shape[1]:]
  print("*"*20)
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
  print(f"Answers:      {example['answer']}")
  print(f"Prediction:   {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
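The script overrides two config values, `rope_theta=200e6` and `max_position_embeddings=81920` (80 x 1024 tokens). In standard rotary position embeddings (RoPE), dimension pair i of a head of size d is rotated by position · base^(-2i/d) radians, so raising the base slows every pair except the first and shrinks the rotation angles reached at long distances; enlarging the base is a common way to stretch a model's usable context. The sketch below compares Llama-3's default base of 500,000 with the 200M base used here, assuming a head dimension of 128 (hidden size 4096 / 32 heads in Llama-3-8B). It illustrates the textbook RoPE formula, not code from this repository.

```python
import math
from typing import List


def rope_inv_freq(base: float, head_dim: int) -> List[float]:
    """Rotation frequencies base^(-2i/d) for i = 0 .. d/2 - 1 (standard RoPE)."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def rotation_angle(position: int, freq: float) -> float:
    """Angle (radians) by which a dimension pair is rotated at a given position."""
    return position * freq


HEAD_DIM = 128  # Llama-3-8B: hidden size 4096 / 32 attention heads
for base in (500_000.0, 200e6):  # Llama-3 default vs. the value set in the script above
    freqs = rope_inv_freq(base, HEAD_DIM)
    longest_wavelength = 2 * math.pi / freqs[-1]  # tokens per full turn of the slowest pair
    angle_at_80k = rotation_angle(81920, freqs[-1])
    print(f"base={base:.0e}  longest wavelength={longest_wavelength:.3g} tokens  "
          f"slowest-pair angle at 81920={angle_at_80k:.3g} rad")
```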