---
license: mit
pipeline_tag: text-generation
---

<div align="center">
<h1>Llama-3-8B-Instruct-80K-QLoRA</h1>

<a href="https://github.com/FlagOpen/FlagEmbedding/tree/master/Long_LLM/">[Data&Code]</a>
</div>

We extend the context length of Llama-3-8B-Instruct from 8K to 80K using QLoRA and 3.5K long-context training samples synthesized with GPT-4. Training is efficient: the entire cycle takes 8 hours on a single 8xA800 (80G) machine. Nevertheless, the resulting model performs strongly on a series of downstream long-context evaluation benchmarks.


# Evaluation

All of the following evaluation results can be reproduced by following the instructions [here](https://github.com/FlagOpen/FlagEmbedding/tree/master/Long_LLM/activation_beacon/new/docs/llama3-8b-instruct-qlora-80k.md).

## Needle in a Haystack
We evaluate the model on the Needle-In-A-HayStack task using the official setting. The blue vertical line indicates the training context length, i.e. 80K. 

<img src="data/needle.png"></img>
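
For readers unfamiliar with the task, the idea can be sketched as follows: a short "needle" sentence is hidden at a chosen relative depth inside long filler text, and the model is asked to retrieve it. This is a simplified illustration only, not the official harness (which measures context length and depth in tokens); the function names `insert_needle` and `build_prompt` and the example strings are hypothetical.

```python
def insert_needle(sentences, needle, depth):
    """Return a new sentence list with `needle` placed at relative `depth`.

    depth=0.0 puts the needle first, depth=1.0 puts it last.
    """
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be between 0 and 1")
    position = int(len(sentences) * depth)
    return sentences[:position] + [needle] + sentences[position:]


def build_prompt(sentences, needle, depth, question):
    """Join the haystack (with the needle inserted) and append the question."""
    context = " ".join(insert_needle(sentences, needle, depth))
    return f"{context}\n\n{question}"


# Hypothetical filler text and needle, for illustration only.
filler = [f"Filler sentence number {i}." for i in range(4)]
needle = "The secret code is 4127."
prompt = build_prompt(filler, needle, depth=0.5, question="What is the secret code?")
print(prompt)
```

Sweeping `depth` and the amount of filler text yields the two axes of the heatmap shown above.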


## LongBench
We evaluate the model on [LongBench](https://arxiv.org/abs/2308.14508) using 32K context length and the official prompt template. For [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), we use 8K context length.

|Model|Single-Doc QA|Multi-Doc QA|Summarization|Few-Shot Learning|Synthetic|Code|
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|[meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)|37.33|36.04|26.83|**69.56**|37.75|53.24|
|[gradientai/Llama-3-8B-Instruct-262k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-262k)|37.29|31.20|26.18|67.25|44.25|**62.71**|
|[Llama-3-8B-Instruct-80K-QLoRA]()|**43.57**|**43.07**|**28.93**|69.15|**48.50**|51.95|

## InfiniteBench
We evaluate the model on [InfiniteBench](https://arxiv.org/pdf/2402.13718.pdf) using 80K context length and the official prompt template. The results for GPT-4 are copied from the [paper](https://arxiv.org/pdf/2402.13718.pdf). For [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), we use 8K context length.

|Model|LongBookQA Eng|
|:-:|:-:|
|GPT4|22.22|
|[meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)|7.00|
|[gradientai/Llama-3-8B-Instruct-262k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-262k)|20.30|
|[Llama-3-8B-Instruct-80K-QLoRA]()|**30.92**|

## Topic Retrieval
We evaluate the model on the [Topic Retrieval](https://lmsys.org/blog/2023-06-29-longchat/) task with `[5,10,15,20,25,30,40,50,60,70]` topics.

<img src="data/topic_retrieval.png"></img>


## MMLU
We evaluate the model's zero-shot performance on the MMLU benchmark as a reflection of its short-context capability.

|Model|STEM|Social Sciences|Humanities|Others|Avg|
|:-:|:-:|:-:|:-:|:-:|:-:|
|[meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)|**53.87**|**75.66**|**69.44**|69.75|**65.91**|
|[gradientai/Llama-3-8B-Instruct-262k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-262k)|52.10|73.26|67.15|**69.80**|64.34|
|[Llama-3-8B-Instruct-80K-QLoRA]()|53.10|73.24|67.32|68.79|64.44|

# Environment
```bash
torch==2.2.2
flash_attn==2.5.6
transformers==4.39.3
peft==0.10.0
```
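
Before running the usage code, it may help to confirm that the installed packages match the versions pinned above. The following is a minimal sketch using only the standard library; the helper names `find_mismatches` and `installed_versions` are hypothetical, and the lookup assumes each package is registered under the distribution name shown in the list.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the Environment list above.
PINNED = {
    "torch": "2.2.2",
    "flash_attn": "2.5.6",
    "transformers": "4.39.3",
    "peft": "0.10.0",
}


def find_mismatches(pinned, installed):
    """Return {package: (expected, found)} for every pin that is not met."""
    mismatches = {}
    for name, expected in pinned.items():
        found = installed.get(name)
        # Ignore local build suffixes such as "+cu121" when comparing.
        if found is None or found.split("+")[0] != expected:
            mismatches[name] = (expected, found)
    return mismatches


def installed_versions(names):
    """Look up installed versions; packages that are not installed map to None."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    report = find_mismatches(PINNED, installed_versions(PINNED))
    for name, (expected, found) in report.items():
        print(f"{name}: expected {expected}, found {found}")
```

An empty report means every pin is satisfied. Other versions may also work, but the pinned ones are the combination listed here.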

# Usage
```python
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
peft_id = "namespace-Pt/Llama-3-8B-Instruct-80K-QLoRA"

torch_dtype = torch.bfloat16
# place the model on GPU
device_map = {"": "cuda"}

tokenizer = AutoTokenizer.from_pretrained(model_id)

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map=device_map,
    attn_implementation="flash_attention_2",

    # NOTE: expand rope base
    rope_theta=200e6,
)

model = PeftModel.from_pretrained(
    base_model,
    peft_id,
    torch_dtype=torch_dtype,
    device_map=device_map,
)
# NOTE: merge LoRA weights
model = model.merge_and_unload().eval()

with torch.no_grad():
  # short context
  messages = [{"role": "user", "content": "Tell me about yourself."}]
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
  outputs = model.generate(**inputs, max_new_tokens=50)[:, inputs["input_ids"].shape[1]:]
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
  print(f"Output:       {tokenizer.decode(outputs[0])}")

  # long context
  with open("data/narrativeqa.json", encoding="utf-8") as f:
    example = json.load(f)
  messages = [{"role": "user", "content": example["context"]}]
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
  outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
  print("*"*20)
  print(f"Input Length: {inputs['input_ids'].shape[1]}")
  print(f"Answers:      {example['answer']}")
  print(f"Prediction:   {tokenizer.decode(outputs[0])}")
```
You may see messages such as:
`This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (8192). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.` or `Setting pad_token_id to eos_token_id:128001 for open-end generation`. Both are harmless here and can be safely ignored.
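
To give some intuition for the `rope_theta=200e6` override in the usage code, the sketch below compares the rotary wavelengths produced by the expanded base with those of the base model. It uses the standard RoPE frequency formula and assumes the stock Llama-3-8B-Instruct configuration (head dimension 128, `rope_theta` of 500000); these values and the helper names are assumptions for illustration, not taken from this model card.

```python
import math

HEAD_DIM = 128              # assumed: 4096 hidden size / 32 attention heads
DEFAULT_THETA = 500_000.0   # assumed: rope_theta of the stock base model
EXPANDED_THETA = 200e6      # value passed in the usage code above


def rope_wavelengths(theta, head_dim=HEAD_DIM):
    """Wavelength, in token positions, of each rotary frequency pair.

    RoPE rotates pair i with inverse frequency theta ** (-2i / head_dim),
    so one full rotation takes 2 * pi * theta ** (2i / head_dim) positions.
    """
    return [2 * math.pi * theta ** (2 * i / head_dim) for i in range(head_dim // 2)]


def count_unfinished(theta, context_len, head_dim=HEAD_DIM):
    """Count frequency pairs that do not complete a rotation within context_len."""
    return sum(w > context_len for w in rope_wavelengths(theta, head_dim))


for name, theta in [("default", DEFAULT_THETA), ("expanded", EXPANDED_THETA)]:
    longest = rope_wavelengths(theta)[-1]
    unfinished = count_unfinished(theta, context_len=80_000)
    print(f"{name}: longest wavelength ~{longest:,.0f} positions, "
          f"{unfinished} of {HEAD_DIM // 2} pairs span more than 80K positions")
```

A larger base stretches the low-frequency wavelengths, so more frequency pairs stay within a single rotation across the 80K window. The override must be supplied at load time as shown in the usage code, since the LoRA weights were trained with it.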