---
license: apache-2.0
---
# MoH: Multi-Head Attention as Mixture-of-Head Attention

**Paper or resources for more information:**
[[Paper]()] [[Code](https://github.com/SkyworkAI/MoH)]

## ⚡ Overview
We propose Mixture-of-Head attention (MoH), a new architecture that treats attention heads as experts in the Mixture-of-Experts (MoE) mechanism. MoH has two significant advantages:
* First, MoH enables each token to select the appropriate attention heads, enhancing inference efficiency without compromising accuracy or increasing the number of parameters. 
* Second, MoH replaces the standard summation in multi-head attention with a weighted summation, introducing flexibility to the attention mechanism and unlocking extra performance potential.
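The two mechanisms above (per-token head selection and a weighted sum over heads) can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. The linear router `w_router`, top-k selection with a softmax renormalized over the selected heads, and every function and parameter name are assumptions for illustration, and any further routing details from the paper are not reproduced. For clarity the sketch computes every head and then zeroes the unselected ones, so it shows the math of MoH rather than its inference savings.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax; entries equal to -inf receive weight 0.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def moh_attention(x, wq, wk, wv, wo, w_router, top_k):
    """Simplified mixture-of-head attention for a single sequence (sketch).

    x:        (T, D) token representations
    wq/wk/wv: (H, D, Dh) per-head query/key/value projections
    wo:       (H, Dh, D) per-head slices of the output projection
    w_router: (D, H) hypothetical linear router weights
    top_k:    number of heads each token activates
    Returns the output (T, D) and the per-token head weights (T, H).
    """
    dh = wq.shape[-1]
    # Per-head queries, keys and values: (H, T, Dh).
    q = np.einsum("td,hde->hte", x, wq)
    key = np.einsum("td,hde->hte", x, wk)
    val = np.einsum("td,hde->hte", x, wv)
    # Scaled dot-product attention inside each head: (H, T, Dh).
    attn = softmax(q @ key.transpose(0, 2, 1) / np.sqrt(dh), axis=-1)
    head_out = attn @ val
    # Project each head back to model width: (H, T, D).
    head_proj = np.einsum("hte,hed->htd", head_out, wo)

    # Hypothetical router: score every head for every token, keep the top_k.
    logits = x @ w_router  # (T, H)
    top = np.argsort(-logits, axis=-1)[:, :top_k]
    mask = np.zeros(logits.shape, dtype=bool)
    np.put_along_axis(mask, top, True, axis=-1)
    # Softmax over the selected heads only; unselected heads get weight 0.
    gates = softmax(np.where(mask, logits, -np.inf), axis=-1)

    # Weighted summation over heads (plain MHA would use gates == 1).
    out = np.einsum("th,htd->td", gates, head_proj)
    return out, gates


rng = np.random.default_rng(0)
T, D, H, Dh = 5, 16, 4, 4
x = rng.standard_normal((T, D))
wq, wk, wv = (rng.standard_normal((H, D, Dh)) for _ in range(3))
wo = rng.standard_normal((H, Dh, D))
w_router = rng.standard_normal((D, H))

out, gates = moh_attention(x, wq, wk, wv, wo, w_router, top_k=3)
print(out.shape, (gates > 0).sum(axis=-1))  # (5, 16) [3 3 3 3 3]
```

Standard multi-head attention corresponds to fixing every gate to 1, i.e. `out = head_proj.sum(axis=0)`; the router replaces that fixed sum with token-dependent weights, and an efficient implementation would evaluate only the selected heads.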



## 😮 Highlights
### 💡 General Framework
We evaluate our proposed MoH across various popular model frameworks, including Vision Transformers (ViT) for image classification, Diffusion models with Transformers (DiT) for class-conditional image generation, and Large Language Models (LLMs) for language tasks.

<div align=center>

| Code | Hugging Face Model |
|:---:|:---:|
| **[MoH-ViT](https://github.com/SkyworkAI/MoH/tree/main/MoH-ViT)** | 🤗 [MoH-ViT-B-75](https://huggingface.co/Chat-UniVi/MoH-ViT-B-75), [MoH-ViT-B-50](https://huggingface.co/Chat-UniVi/MoH-ViT-B-50), [MoH-ViT-S-80](https://huggingface.co/Chat-UniVi/MoH-ViT-S-80), [MoH-ViT-S-75](https://huggingface.co/Chat-UniVi/MoH-ViT-S-75) |
| **[MoH-DiT](https://github.com/SkyworkAI/MoH/tree/main/MoH-DiT)** | 😊 [MoH-DiT-90](https://huggingface.co/Chat-UniVi/MoH-DiT-XL-90) |
| **[MoH-LLaMA3-8B](https://github.com/SkyworkAI/MoH/tree/main/MoH-LLaMA3)** | 😊 [MoH-LLaMA3-8B](https://huggingface.co/Chat-UniVi/MoH-LLaMA3-8B) |

</div>

### 🔥 High Performance
Extensive experiments on ViT, DiT, and LLMs demonstrate that MoH outperforms multi-head attention while using only **50% to 90%** of the attention heads.

### 🤗 Support for Continue-Tuning from Multi-Head Attention Models
We demonstrate that pre-trained multi-head attention models, such as LLaMA3-8B, can be continue-tuned into MoH models. Notably, MoH-LLaMA3-8B achieves an average accuracy of 64.0% across 14 benchmarks, outperforming LLaMA3-8B by 2.4% while using only 75% of the attention heads.
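To make head-budget figures like "75% of the attention heads" concrete, the sketch below converts an activation ratio into a head count. It assumes (the card does not say) that the percentage is a fixed top-k applied per token in every layer and that only query heads are counted, since LLaMA3-8B uses grouped-query attention with 32 query heads and 8 key/value heads per layer. `active_heads` is a hypothetical helper, not part of the MoH code.

```python
def active_heads(total_heads: int, ratio: float) -> int:
    """Heads activated per token under an assumed fixed top-k budget.

    Hypothetical helper: treats the activation ratio as a fixed per-token,
    per-layer top-k and rounds to the nearest integer (at least one head).
    """
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    return max(1, round(total_heads * ratio))


# LLaMA3-8B has 32 query heads per attention layer.
LLAMA3_8B_HEADS = 32

for ratio in (0.50, 0.75, 0.90):
    k = active_heads(LLAMA3_8B_HEADS, ratio)
    print(f"{ratio:.0%}: {k} of {LLAMA3_8B_HEADS} heads active, "
          f"{LLAMA3_8B_HEADS - k} skipped")
```

Under these assumptions, MoH-LLaMA3-8B at 75% would evaluate 24 of the 32 query heads for each token.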


The MoH model quickly recovers over **95%** of the original model's performance within a training budget of 10B tokens, and its performance then continues to improve gradually as the number of training tokens increases.

## 🤖 API for Model Inference
To load the model from the Hugging Face Hub or from a local directory, use the following code snippet.

### Base Model Inference
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

question = "Hello!"

model = AutoModelForCausalLM.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True)

inputs = tokenizer(question, return_tensors='pt').to(model.device)
response = model.generate(**inputs, max_length=128)  # passes input_ids and attention_mask
print(tokenizer.decode(response.cpu()[0], skip_special_tokens=True))
```

### Chat Model Inference
Coming soon...


## 🏗️ Training & Validating
* The training code is built on [Skywork-MoE](https://github.com/SkyworkAI/Skywork-MoE). Until Skywork-MoE is open-sourced, we cannot open-source the MoH-LLaMA3 training code on its own. We will release the training code once approval is complete.
* The evaluation is performed on multiple key benchmarks using the [EleutherAI Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

```bash
# For example, test MoH-LLaMA3-8B on winogrande

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
--main_process_port 2004 -m lm_eval --model hf \
--model_args pretrained=Chat-UniVi/MoH-LLaMA3-8B,trust_remote_code=True \
--tasks winogrande \
--batch_size 1 \
--output_path Results/winogrande
```

## ✏️ Citation
If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper:
```
@article{jin2024moh,
  title={MoH: Multi-Head Attention as Mixture-of-Head Attention}, 
  author={Peng Jin and Bo Zhu and Li Yuan and Shuicheng Yan},
  year={2024}
}
```