MoH: Multi-Head Attention as Mixture-of-Head Attention

Paper or resources for more information: [Paper] [Code]

⚡ Overview

We propose Mixture-of-Head attention (MoH), a new architecture that treats attention heads as experts in the Mixture-of-Experts (MoE) mechanism. MoH has two significant advantages:

  • First, MoH enables each token to select the appropriate attention heads, enhancing inference efficiency without compromising accuracy or increasing the number of parameters.
  • Second, MoH replaces the standard summation in multi-head attention with a weighted summation, introducing flexibility to the attention mechanism and unlocking extra performance potential.
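The two properties above can be illustrated with a minimal sketch for a single token: a router scores each head, only the `top_k` highest-scoring heads are kept, and their outputs are combined with softmax weights instead of a plain sum. This is a simplified illustration, not the released implementation; the function name `moh_combine`, the pure-Python vectors, and the softmax-over-selected-heads weighting are assumptions made here for clarity.

```python
import math

def moh_combine(head_outputs, router_logits, top_k):
    """Weighted sum of the top_k heads chosen by the router for one token.

    head_outputs: list of per-head output vectors, one per attention head
                  (assumed to be already projected to the model dimension).
    router_logits: one routing score per head for this token.
    """
    if not 1 <= top_k <= len(head_outputs):
        raise ValueError("top_k must be between 1 and the number of heads")
    # Rank heads by routing score and keep the top_k highest.
    ranked = sorted(range(len(router_logits)),
                    key=lambda i: router_logits[i], reverse=True)
    active = ranked[:top_k]
    # Softmax over the selected heads only, so their weights sum to 1.
    peak = max(router_logits[i] for i in active)
    exps = {i: math.exp(router_logits[i] - peak) for i in active}
    total = sum(exps.values())
    weights = {i: exps[i] / total for i in active}
    # Weighted summation replaces the plain sum of multi-head attention.
    dim = len(head_outputs[0])
    out = [0.0] * dim
    for i in active:
        for d in range(dim):
            out[d] += weights[i] * head_outputs[i][d]
    return out, weights

# Toy example: 4 heads, 2-dimensional outputs, 2 active heads per token.
heads = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]]
logits = [2.0, 2.0, -5.0, -5.0]
out, weights = moh_combine(heads, logits, top_k=2)
print(out, weights)
```

Choosing `top_k` smaller than the number of heads is what lets each token skip the heads it does not need, while the learned weights let the active heads contribute unequally.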

😮 Highlights

💡 General Framework

We evaluate our proposed MoH across various popular model frameworks, including Vision Transformers (ViT) for image classification, Diffusion models with Transformers (DiT) for class-conditional image generation, and Large Language Models (LLMs) for language tasks.

| Code | HuggingFace Model |
| --- | --- |
| MoH-ViT | 🤗 MoH-ViT-B-75, MoH-ViT-B-50, MoH-ViT-S-80, MoH-ViT-S-75 |
| MoH-DiT | 😊 MoH-DiT-90 |
| MoH-LLaMA3-8B | 😊 MoH-LLaMA3-8B |

🔥 High Performance

Extensive experiments on ViT, DiT, and LLMs demonstrate that MoH outperforms multi-head attention while using only 50% to 90% of the attention heads.

🤗 Supports Continue-Tuning from Multi-Head Attention Models

We demonstrate that pre-trained multi-head attention models, such as LLaMA3-8B, can be further continue-tuned into our MoH models. Notably, MoH-LLaMA3-8B achieves an average accuracy of 64.0% across 14 benchmarks, outperforming LLaMA3-8B by 2.4% while using only 75% of the attention heads.

The MoH model quickly recovers to over 95% of the original model's performance within a training budget of 10B tokens. Performance then improves gradually as the number of training tokens increases.

🤖 API for Model Inference

To load the model from the Hugging Face Hub or from a local directory, you can use the following code snippets.

Base Model Inference

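from transformers import AutoModelForCausalLM, AutoTokenizer

question = "Hello!"

# trust_remote_code=True is needed because the MoH architecture ships as custom code in the model repo.
# device_map='auto' places the weights on the available devices.
model = AutoModelForCausalLM.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True)

# Tokenize the prompt and move the tensors to the model's device.
inputs = tokenizer(question, return_tensors='pt').to(model.device)
# max_length counts the prompt tokens plus the generated tokens.
response = model.generate(inputs.input_ids, max_length=128)
# Decode the first (and only) sequence in the batch.
print(tokenizer.decode(response.cpu()[0], skip_special_tokens=True))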
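For the local case mentioned above, one possible pattern is to pick the model source before calling `from_pretrained`, which accepts either a Hub repository id or a path to a checkpoint directory. The helper `resolve_model_source` and the directory name `./MoH-LLaMA3-8B` below are hypothetical and only illustrate the idea; they are not part of this repository.

```python
from pathlib import Path

HUB_ID = "Chat-UniVi/MoH-LLaMA3-8B"  # repository id used in the snippet above

def resolve_model_source(local_dir=None, hub_id=HUB_ID):
    """Return a local checkpoint directory if it looks usable, else the Hub id.

    Hypothetical helper: a directory is treated as a checkpoint when it
    contains a config.json file, as saved Transformers checkpoints do.
    """
    if local_dir is not None:
        path = Path(local_dir).expanduser()
        if (path / "config.json").is_file():
            return str(path)
    return hub_id

# "./MoH-LLaMA3-8B" is an assumed local directory name, used for illustration.
source = resolve_model_source("./MoH-LLaMA3-8B")
print(source)
```

The returned string can then be passed to `AutoModelForCausalLM.from_pretrained(...)` and `AutoTokenizer.from_pretrained(...)` in place of the hard-coded repository id.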

Chat Model Inference

Coming soon...

🗝️ Training & Validating

  • The training code is built on Skywork-MoE. Until Skywork-MoE is open-sourced, we cannot open-source MoH-LLaMA3 on its own. We will release the training code once approval is complete.
  • The evaluation is performed on multiple key benchmarks using the Eleuther AI Language Model Evaluation Harness.
# For example, test MoH-LLaMA3-8B on winogrande

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
--main_process_port 2004 -m lm_eval --model hf \
--model_args pretrained=Chat-UniVi/MoH-LLaMA3-8B \
--tasks winogrande \
--batch_size 1 \
--output_path Results/winogrande
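To run the same evaluation over several tasks, the command above can be generated per task. The sketch below only assembles the argument list and prints it; the helper name `build_eval_command` is hypothetical, `winogrande` is the only task name taken from this card, and `CUDA_VISIBLE_DEVICES` still has to be set in the environment separately.

```python
import shlex

def build_eval_command(task, model="Chat-UniVi/MoH-LLaMA3-8B",
                       port=2004, batch_size=1, output_root="Results"):
    """Build the argv list for one lm_eval run, mirroring the command above."""
    return [
        "accelerate", "launch",
        "--main_process_port", str(port),
        "-m", "lm_eval",
        "--model", "hf",
        "--model_args", f"pretrained={model}",
        "--tasks", task,
        "--batch_size", str(batch_size),
        "--output_path", f"{output_root}/{task}",
    ]

# Extend this list with other task names as needed (any additions are assumptions).
for task in ["winogrande"]:
    print(shlex.join(build_eval_command(task)))
```

Each printed line can be pasted into a shell, or the list returned by `build_eval_command` can be passed to `subprocess.run`.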

✏️ Citation

If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper:

@article{jin2024moh,
  title={MoH: Multi-Head Attention as Mixture-of-Head Attention}, 
  author={Peng Jin and Bo Zhu and Li Yuan and Shuicheng Yan},
  year={2024}
}