Upload 5 files
Browse files- README.md +68 -3
- config.json +48 -0
- configuration_rene.py +103 -0
- model.safetensors +3 -0
- rene.py +435 -0
README.md
CHANGED
@@ -1,3 +1,68 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
datasets:
|
6 |
+
- allenai/dolma
|
7 |
+
tags:
|
8 |
+
- rene
|
9 |
+
- mamba
|
10 |
+
- cartesia
|
11 |
+
---
|
12 |
+
|
13 |
+
# Model Card for Rene
|
14 |
+
|
15 |
+
Rene is a 1.3 billion-parameter language model trained by [Cartesia](https://cartesia.ai).
|
16 |
+
Rene has a hybrid architecture based on [Mamba-2](https://arxiv.org/abs/2405.21060), with feedforward and sliding window attention layers interspersed.
|
17 |
+
It uses the [allenai/OLMo-1B-hf](https://huggingface.co/allenai/OLMo-1B-hf) tokenizer.
|
18 |
+
Rene was pretrained on 1.5 trillion tokens of the [Dolma-1.7](https://huggingface.co/datasets/allenai/dolma) dataset.
|
19 |
+
For more details, see our [blog post](https://cartesia.ai/blog/on-device).
|
20 |
+
|
21 |
+
## Usage
|
22 |
+
### Installation
|
23 |
+
The Rene model depends on the `cartesia-pytorch` package, which can be installed with `pip` as follows:
|
24 |
+
```shell
|
25 |
+
pip install --no-binary :all: cartesia-pytorch
|
26 |
+
```
|
27 |
+
|
28 |
+
### Generation example
|
29 |
+
```python
|
30 |
+
from cartesia_pytorch import ReneLMHeadModel
|
31 |
+
from transformers import AutoTokenizer
|
32 |
+
|
33 |
+
model = ReneLMHeadModel.from_pretrained("cartesia-ai/Rene-v0.1-1.3b-pytorch").half().cuda()
|
34 |
+
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
|
35 |
+
in_message = ["Rene Descartes was"]
|
36 |
+
inputs = tokenizer(in_message, return_tensors="pt")
|
37 |
+
outputs = model.generate(inputs.input_ids.cuda(), max_length=50, top_k=100, top_p=0.99)
|
38 |
+
out_message = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
39 |
+
print(out_message)
|
40 |
+
# Example output: "Rene Descartes was a French mathematician, philosopher, and scientist. Descartes is famously credited for creating the Cartesian coordinate system: a 3 dimensional representation of points, vectors, and directions. This work is, for the most part" ...
|
41 |
+
```
|
42 |
+
|
43 |
+
### Evaluation example
|
44 |
+
You can use our `cartesia_lm_eval` wrapper around the [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main) to evaluate our model on standard text benchmarks. Example command (clone this repo and run the below from within the `cartesia-pytorch` directory):
|
45 |
+
```shell
|
46 |
+
python -m evals.cartesia_lm_eval --model rene_ssm --model_args pretrained=cartesia-ai/Rene-v0.1-1.3b-pytorch,trust_remote_code=True --trust_remote_code --tasks copa,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --cache_requests true --batch_size auto:4 --output_path outputs/rene_evals/
|
47 |
+
```
|
48 |
+
## Results on common benchmarks
|
49 |
+
| Model | Params (B) | Train Tokens | COPA | HellaSwag | MMLU (5-shot) | PIQA | ARC-e | ARC-c | WinoGrande | OpenBookQA | Average |
|
50 |
+
|------------------------------------------------|------------|--------------|------|-----------|---------------|------|-------|-------|------------|------------|---------|
|
51 |
+
| allenai/OLMo-1B-hf | 1.2 | 3.0 | 82.0 | 62.9 | 26.2 | 75.1 | 57.4 | 31.1 | 60.0 | 36.2 | 53.9 |
|
52 |
+
| apple/OpenELM-1\_1B | 1.1 | 1.5 | 81.0 | 64.8 | 27.1 | 75.6 | 55.4 | 32.3 | 61.9 | 36.2 | 54.3 |
|
53 |
+
| state-spaces/mamba2-1.3b | 1.3 | 0.3 | 82.0 | 60.0 | 25.8 | 73.7 | 64.2 | 33.3 | 61.0 | 37.8 | 54.7 |
|
54 |
+
| microsoft/phi-1\_5 | 1.4 | 0.15 | 79.0 | 62.6 | 42.5 | 75.5 | 73.2 | 48.0 | 72.8 | 48.0 | 62.7 |
|
55 |
+
| Qwen/Qwen2-1.5B | 1.5 | 7.0 | 80.0 | 65.4 | 56.0 | 75.5 | 60.4 | 35.0 | 65.8 | 36.4 | 59.3 |
|
56 |
+
| RWKV/rwkv-6-world-1b6 | 1.6 | 1.1 | 84.0 | 58.3 | 25.9 | 73.5 | 56.7 | 34.1 | 60.0 | 37.4 | 53.7 |
|
57 |
+
| stabilityai/stablelm-2-1\_6b | 1.6 | 4.0 | 86.0 | 69.0 | 38.1 | 76.7 | 68.1 | 38.9 | 63.6 | 38.8 | 59.9 |
|
58 |
+
| HuggingFaceTB/SmolLM-1.7B | 1.7 | 1.0 | 76.0 | 65.8 | 29.9 | 76.1 | 73.5 | 46.4 | 60.9 | 42.0 | 58.8 |
|
59 |
+
| h2oai/h2o-danube2-1.8b-base | 1.8 | 3.0 | 82.0 | 72.4 | 39.9 | 77.3 | 69.0 | 39.9 | 63.9 | 41.4 | 60.7 |
|
60 |
+
| google/recurrentgemma-2b | 2.7 | 2.0 | 62.0 | 61.8 | 32.3 | 68.8 | 46.4 | 29.9 | 57.1 | 29.0 | 48.4 |
|
61 |
+
| cognitivecomputations/TinyDolphin-2.8.1-1.1b | 1.1 | | 71.0 | 59.9 | 25.7 | 73.1 | 55.8 | 33.0 | 59.7 | 36.6 | 51.9 |
|
62 |
+
| cartesia-ai/Rene-v0.1-1.3b-pytorch (OUR MODEL) | 1.3 | 1.5 | 82.0 | 69.4 | 32.6 | 77.5 | 61.7 | 34.4 | 62.9 | 39.2 | 57.5 |
|
63 |
+
|
64 |
+
## Bias, Risks, and Limitations
|
65 |
+
Rene is a pretrained base model which has not undergone any alignment or instruction tuning, and therefore does not have any moderation or safety guarantees. Users should implement appropriate guardrails and moderation mechanisms based on their particular needs in order to ensure responsible and ethical usage.
|
66 |
+
|
67 |
+
## About Cartesia
|
68 |
+
At [Cartesia](https://cartesia.ai/), we're building real-time multimodal intelligence for every device.
|
config.json
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"attn_cfg": {
|
3 |
+
"causal": true,
|
4 |
+
"head_dim": 64,
|
5 |
+
"num_heads": 48,
|
6 |
+
"out_proj_bias": true,
|
7 |
+
"qkv_proj_bias": true,
|
8 |
+
"sliding_window_length": 2048
|
9 |
+
},
|
10 |
+
"attn_layer_idx": [
|
11 |
+
6,
|
12 |
+
18,
|
13 |
+
30,
|
14 |
+
42
|
15 |
+
],
|
16 |
+
"d_model": 2048,
|
17 |
+
"eos_token_id": 50279,
|
18 |
+
"mlp_cfg": {},
|
19 |
+
"mlp_layer_idx": [
|
20 |
+
2,
|
21 |
+
5,
|
22 |
+
8,
|
23 |
+
11,
|
24 |
+
14,
|
25 |
+
17,
|
26 |
+
20,
|
27 |
+
23,
|
28 |
+
26,
|
29 |
+
29,
|
30 |
+
32,
|
31 |
+
35,
|
32 |
+
38,
|
33 |
+
41,
|
34 |
+
44,
|
35 |
+
47
|
36 |
+
],
|
37 |
+
"model_type": "rene",
|
38 |
+
"n_layer": 48,
|
39 |
+
"pad_token_id": 1,
|
40 |
+
"pad_vocab_size_multiple": 16,
|
41 |
+
"residual_in_fp32": true,
|
42 |
+
"rms_norm": true,
|
43 |
+
"ssm_cfg": {
|
44 |
+
"norm_before_gate": true
|
45 |
+
},
|
46 |
+
"tie_word_embeddings": true,
|
47 |
+
"vocab_size": 50280
|
48 |
+
}
|
configuration_rene.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional
|
2 |
+
|
3 |
+
from transformers.configuration_utils import PretrainedConfig
|
4 |
+
|
5 |
+
|
6 |
+
class ReneConfig(PretrainedConfig):
|
7 |
+
r"""Configuration class for the Rene model.
|
8 |
+
|
9 |
+
This is the configuration class to store the configuration of a [`ReneLMHeadModel`].
|
10 |
+
It is used to instantiate a Rene model according to the specified arguments,
|
11 |
+
defining the model architecture. Instantiating a configuration with the defaults will yield
|
12 |
+
a similar configuration to that of the Rene-v0.1-1.3b-pytorch model.
|
13 |
+
[cartesia-ai/Rene-v0.1-1.3b-pytorch](https://huggingface.co/cartesia-ai/Rene-v0.1-1.3b-pytorch)
|
14 |
+
|
15 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
16 |
+
documentation from [`PretrainedConfig`] for more information.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
d_model (`int`, *optional*, defaults to 2048):
|
20 |
+
Dimension of the hidden representations.
|
21 |
+
n_layer (`int`, *optional*, defaults to 48):
|
22 |
+
Number of architecture blocks.
|
23 |
+
vocab_size (`int`, *optional*, defaults to 50280):
|
24 |
+
Vocabulary size of the Rene model. Defines the number of different tokens that can be represented by the
|
25 |
+
`inputs_ids` passed when calling [`ReneModel`].
|
26 |
+
ssm_cfg (`dict`, *optional*):
|
27 |
+
Configuration parameters for the SSM layers.
|
28 |
+
attn_layer_idx (`List[int]`, *optional*):
|
29 |
+
Indices of the architecture blocks that should have attention layers.
|
30 |
+
attn_cfg (`dict`, *optional*):
|
31 |
+
Configuration parameters for the attention layers.
|
32 |
+
mlp_layer_idx (`List[int]`, *optional*):
|
33 |
+
Indices of the architecture blocks that should have MLP layers.
|
34 |
+
mlp_cfg (`dict`, *optional*):
|
35 |
+
Configuration parameters for the MLP layers.
|
36 |
+
rms_norm (`bool`, *optional*, defaults to `True`):
|
37 |
+
Whether to use RMSNorm (instead of LayerNorm).
|
38 |
+
residual_in_fp32 (`bool`, *optional*, defaults to `True`):
|
39 |
+
Whether to keep residual values in fp32.
|
40 |
+
pad_vocab_size_multiple (`int`, *optional*, defaults to 16):
|
41 |
+
Pad the vocabulary size up to the next multiple of this value.
|
42 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
43 |
+
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
44 |
+
model has a output word embedding layer.
|
45 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
46 |
+
The id of the padding token.
|
47 |
+
bos_token_id (`int`, *optional*):
|
48 |
+
The id of the "beginning-of-sequence" token.
|
49 |
+
eos_token_id (`int`, *optional*, defaults to 50279):
|
50 |
+
The id of the "end-of-sequence" token.
|
51 |
+
"""
|
52 |
+
|
53 |
+
model_type = "rene"
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
d_model: int = 2048,
|
58 |
+
n_layer: int = 48,
|
59 |
+
vocab_size: int = 50280,
|
60 |
+
ssm_cfg: Optional[Dict] = None,
|
61 |
+
attn_layer_idx: Optional[List] = None,
|
62 |
+
attn_cfg: Optional[Dict] = None,
|
63 |
+
mlp_layer_idx: Optional[List] = None,
|
64 |
+
mlp_cfg: Optional[Dict] = None,
|
65 |
+
rms_norm: bool = True,
|
66 |
+
residual_in_fp32: bool = True,
|
67 |
+
pad_vocab_size_multiple: int = 16,
|
68 |
+
tie_word_embeddings: bool = True,
|
69 |
+
pad_token_id=1,
|
70 |
+
bos_token_id=None,
|
71 |
+
eos_token_id=50279,
|
72 |
+
**kwargs,
|
73 |
+
):
|
74 |
+
if ssm_cfg is None:
|
75 |
+
ssm_cfg = {}
|
76 |
+
if attn_layer_idx is None:
|
77 |
+
attn_layer_idx = []
|
78 |
+
if attn_cfg is None:
|
79 |
+
attn_cfg = {}
|
80 |
+
if mlp_layer_idx is None:
|
81 |
+
mlp_layer_idx = []
|
82 |
+
if mlp_cfg is None:
|
83 |
+
mlp_cfg = {}
|
84 |
+
|
85 |
+
self.d_model = d_model
|
86 |
+
self.n_layer = n_layer
|
87 |
+
self.vocab_size = vocab_size
|
88 |
+
self.ssm_cfg = ssm_cfg
|
89 |
+
self.attn_layer_idx = attn_layer_idx
|
90 |
+
self.attn_cfg = attn_cfg
|
91 |
+
self.mlp_layer_idx = mlp_layer_idx
|
92 |
+
self.mlp_cfg = mlp_cfg
|
93 |
+
self.rms_norm = rms_norm
|
94 |
+
self.residual_in_fp32 = residual_in_fp32
|
95 |
+
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
96 |
+
self.tie_word_embeddings = tie_word_embeddings
|
97 |
+
super().__init__(
|
98 |
+
bos_token_id=bos_token_id,
|
99 |
+
eos_token_id=eos_token_id,
|
100 |
+
pad_token_id=pad_token_id,
|
101 |
+
tie_word_embeddings=tie_word_embeddings,
|
102 |
+
**kwargs,
|
103 |
+
)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5a62c98beb82cd70e4ff866b3cd479f836f17676a76b82a337a1dde2126673de
|
3 |
+
size 2866628624
|
rene.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
from flash_attn import flash_attn_with_kvcache
|
8 |
+
from mamba_ssm.models.mixer_seq_simple import _init_weights
|
9 |
+
from mamba_ssm.modules.mamba2 import Mamba2
|
10 |
+
from mamba_ssm.modules.mha import _update_kv_cache
|
11 |
+
from mamba_ssm.utils.generation import GenerationMixin as MambaGenerationMixin
|
12 |
+
from transformers.modeling_outputs import CausalLMOutput
|
13 |
+
from transformers.modeling_utils import PreTrainedModel
|
14 |
+
|
15 |
+
from .configuration_rene import ReneConfig
|
16 |
+
|
17 |
+
|
18 |
+
class ReneMLP(nn.Module):
|
19 |
+
"""One-hidden-layer network with GELU activation.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
d_input: Block input dimension.
|
23 |
+
d_output: Block output dimension.
|
24 |
+
expand: Block expansion factor.
|
25 |
+
bias: Use biases in linear layers.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, d_input, d_output=None, expand=3, bias=True, device=None, dtype=None):
|
29 |
+
super().__init__()
|
30 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
31 |
+
self.d_input = d_input
|
32 |
+
self.d_output = d_input if d_output is None else d_output
|
33 |
+
self.d_inner = int(round(expand * d_input))
|
34 |
+
self.in_proj = nn.Linear(self.d_input, self.d_inner, bias=bias, **factory_kwargs)
|
35 |
+
self.activation = nn.GELU()
|
36 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_input, bias=bias, **factory_kwargs)
|
37 |
+
|
38 |
+
def forward(self, x, inference_params=None):
|
39 |
+
"""Forward pass through the MLP module."""
|
40 |
+
y = self.in_proj(x)
|
41 |
+
y = self.activation(y)
|
42 |
+
y = self.out_proj(y)
|
43 |
+
return y
|
44 |
+
|
45 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
46 |
+
"""Allocate inference cache for ReneMLP. (There is nothing to cache for this module)."""
|
47 |
+
return None
|
48 |
+
|
49 |
+
|
50 |
+
class ReneMHA(nn.Module):
|
51 |
+
"""Multi-head self-attention. Adapted from mamba_ssm MHA class."""
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
embed_dim,
|
56 |
+
num_heads,
|
57 |
+
num_heads_kv=None,
|
58 |
+
head_dim=None, # If None, use embed_dim // num_heads
|
59 |
+
qkv_proj_bias=True,
|
60 |
+
out_proj_bias=True,
|
61 |
+
softmax_scale=None,
|
62 |
+
causal=True,
|
63 |
+
sliding_window_length=None, # If None, infinite context
|
64 |
+
layer_idx=None,
|
65 |
+
device=None,
|
66 |
+
dtype=None,
|
67 |
+
) -> None:
|
68 |
+
"""
|
69 |
+
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
|
70 |
+
return_residual: whether to return the input x along with the output. This is for
|
71 |
+
performance reason: for post-norm architecture, returning the input allows us
|
72 |
+
to fuse the backward of nn.Linear with the residual connection.
|
73 |
+
"""
|
74 |
+
super().__init__()
|
75 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
76 |
+
self.embed_dim = embed_dim
|
77 |
+
self.layer_idx = layer_idx
|
78 |
+
self.softmax_scale = softmax_scale
|
79 |
+
self.causal = causal
|
80 |
+
assert self.causal, "Rene does not yet support non-causal modeling"
|
81 |
+
|
82 |
+
self.num_heads = num_heads
|
83 |
+
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
84 |
+
assert (
|
85 |
+
self.num_heads % self.num_heads_kv == 0
|
86 |
+
), "num_heads must be divisible by num_heads_kv"
|
87 |
+
if head_dim is None:
|
88 |
+
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
89 |
+
self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
|
90 |
+
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
91 |
+
out_dim = self.head_dim * self.num_heads
|
92 |
+
|
93 |
+
self.sliding_window_length = sliding_window_length
|
94 |
+
if self.sliding_window_length is None:
|
95 |
+
self.window_size = (-1, -1)
|
96 |
+
else:
|
97 |
+
self.window_size = (self.sliding_window_length - 1, 0) # for flash_attn
|
98 |
+
|
99 |
+
self.in_proj = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
|
100 |
+
self.out_proj = nn.Linear(out_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
|
101 |
+
|
102 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
103 |
+
"""Allocate inference cache for the multi-head self-attention module."""
|
104 |
+
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
105 |
+
device = self.out_proj.weight.device
|
106 |
+
kv_cache = torch.empty(
|
107 |
+
batch_size,
|
108 |
+
max_seqlen,
|
109 |
+
2,
|
110 |
+
self.num_heads_kv,
|
111 |
+
self.head_dim,
|
112 |
+
dtype=dtype,
|
113 |
+
device=device,
|
114 |
+
)
|
115 |
+
return kv_cache, None
|
116 |
+
|
117 |
+
def _pytorch_attn(self, q, kv):
|
118 |
+
k, v = kv.unbind(dim=-3)
|
119 |
+
k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
120 |
+
v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
121 |
+
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
122 |
+
L, S = q.size(-2), k.size(-2)
|
123 |
+
if S > self.sliding_window_length:
|
124 |
+
attn_mask = (
|
125 |
+
torch.ones(L, S, dtype=torch.bool)
|
126 |
+
.tril(diagonal=0)
|
127 |
+
.triu(-self.window_size[0])
|
128 |
+
.to(device=q.device)
|
129 |
+
)
|
130 |
+
# Since we pass in an attn_mask explicitly, we need to pass is_causal=False to
|
131 |
+
# `scaled_dot_product_attention` (even though the attn_mask itself is in fact causal).
|
132 |
+
is_causal_arg = False
|
133 |
+
else:
|
134 |
+
# The previous branch would also handle this case correctly, but it is more efficient
|
135 |
+
# to omit the attn_mask when we don't need it.
|
136 |
+
attn_mask = None
|
137 |
+
is_causal_arg = True
|
138 |
+
return F.scaled_dot_product_attention(
|
139 |
+
q, k, v, attn_mask=attn_mask, is_causal=is_causal_arg, scale=self.softmax_scale
|
140 |
+
).transpose(1, 2)
|
141 |
+
|
142 |
+
def _update_kv_cache(self, kv, inference_params):
|
143 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)."""
|
144 |
+
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
145 |
+
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
146 |
+
|
147 |
+
def _update_kvcache_attention(self, q, kv, inference_params):
|
148 |
+
"""Write kv to inference_params, then compute attention."""
|
149 |
+
if inference_params.seqlen_offset == 0 or flash_attn_with_kvcache is None:
|
150 |
+
# TODO: this only uses seqlen_offset and not lengths_per_sample.
|
151 |
+
kv = self._update_kv_cache(kv, inference_params)
|
152 |
+
return self._pytorch_attn(q, kv)
|
153 |
+
else:
|
154 |
+
batch = q.shape[0]
|
155 |
+
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
|
156 |
+
kv_cache = kv_cache[:batch]
|
157 |
+
cache_seqlens = (
|
158 |
+
inference_params.lengths_per_sample[:batch]
|
159 |
+
if inference_params.lengths_per_sample is not None
|
160 |
+
else inference_params.seqlen_offset
|
161 |
+
)
|
162 |
+
return flash_attn_with_kvcache(
|
163 |
+
q,
|
164 |
+
kv_cache[:, :, 0],
|
165 |
+
kv_cache[:, :, 1],
|
166 |
+
kv[:, :, 0],
|
167 |
+
kv[:, :, 1],
|
168 |
+
cache_seqlens=cache_seqlens,
|
169 |
+
softmax_scale=self.softmax_scale,
|
170 |
+
causal=self.causal,
|
171 |
+
window_size=self.window_size,
|
172 |
+
)
|
173 |
+
|
174 |
+
def forward(self, x, inference_params=None):
|
175 |
+
"""Forward pass through the multi-head self-attention module."""
|
176 |
+
if (
|
177 |
+
inference_params is not None
|
178 |
+
and self.layer_idx not in inference_params.key_value_memory_dict
|
179 |
+
):
|
180 |
+
inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
|
181 |
+
x.shape[0], inference_params.max_seqlen, dtype=x.dtype
|
182 |
+
)
|
183 |
+
qkv = self.in_proj(x)
|
184 |
+
q, kv = qkv.split(
|
185 |
+
[self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1
|
186 |
+
)
|
187 |
+
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
188 |
+
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
189 |
+
if inference_params is None:
|
190 |
+
context = self._pytorch_attn(q, kv)
|
191 |
+
else:
|
192 |
+
context = self._update_kvcache_attention(q, kv, inference_params)
|
193 |
+
context = rearrange(context, "... h d -> ... (h d)")
|
194 |
+
out = self.out_proj(context)
|
195 |
+
return out
|
196 |
+
|
197 |
+
|
198 |
+
class Block(nn.Module):
|
199 |
+
"""Simple residual block with normalization that wraps an inner "mixer" module."""
|
200 |
+
|
201 |
+
def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, residual_in_fp32=False):
|
202 |
+
"""
|
203 |
+
dim: The dimension of the input data.
|
204 |
+
mixer_cls: The class of the mixer module.
|
205 |
+
norm_cls: The class of the normalization module.
|
206 |
+
residual_in_fp32: Whether to keep residuals in fp32.
|
207 |
+
"""
|
208 |
+
super().__init__()
|
209 |
+
self.residual_in_fp32 = residual_in_fp32
|
210 |
+
self.norm = norm_cls(dim)
|
211 |
+
self.mixer = mixer_cls(dim)
|
212 |
+
|
213 |
+
def forward(self, x, inference_params=None, **mixer_kwargs):
|
214 |
+
"""Forward pass through the block."""
|
215 |
+
y = self.norm(x.to(dtype=self.norm.weight.dtype))
|
216 |
+
y = self.mixer(y, inference_params=inference_params, **mixer_kwargs)
|
217 |
+
|
218 |
+
residual = x
|
219 |
+
if self.residual_in_fp32:
|
220 |
+
residual = residual.to(torch.float32)
|
221 |
+
y = y + residual
|
222 |
+
y = y.to(dtype=x.dtype)
|
223 |
+
|
224 |
+
return y
|
225 |
+
|
226 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
227 |
+
"""Allocate inference cache for the mixer module."""
|
228 |
+
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
229 |
+
|
230 |
+
|
231 |
+
def _create_block(
|
232 |
+
d_model,
|
233 |
+
norm_cls,
|
234 |
+
ssm_cfg=None,
|
235 |
+
attn_layer_idx=None,
|
236 |
+
attn_cfg=None,
|
237 |
+
mlp_layer_idx=None,
|
238 |
+
mlp_cfg=None,
|
239 |
+
residual_in_fp32=False,
|
240 |
+
layer_idx=None,
|
241 |
+
device=None,
|
242 |
+
dtype=None,
|
243 |
+
):
|
244 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
245 |
+
if ssm_cfg is None:
|
246 |
+
ssm_cfg = {}
|
247 |
+
if attn_layer_idx is None:
|
248 |
+
attn_layer_idx = []
|
249 |
+
if attn_cfg is None:
|
250 |
+
attn_cfg = {}
|
251 |
+
if mlp_layer_idx is None:
|
252 |
+
mlp_layer_idx = []
|
253 |
+
if mlp_cfg is None:
|
254 |
+
mlp_cfg = {}
|
255 |
+
if layer_idx in attn_layer_idx:
|
256 |
+
mixer_cls = partial(ReneMHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
|
257 |
+
elif layer_idx in mlp_layer_idx:
|
258 |
+
mixer_cls = partial(ReneMLP, **mlp_cfg, **factory_kwargs)
|
259 |
+
else:
|
260 |
+
mixer_cls = partial(Mamba2, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
|
261 |
+
return Block(d_model, mixer_cls, norm_cls=norm_cls, residual_in_fp32=residual_in_fp32)
|
262 |
+
|
263 |
+
|
264 |
+
class MixerModel(nn.Module):
|
265 |
+
"""Adapted from mamba_ssm.models.mixer_seq_simple.MixerModel."""
|
266 |
+
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
d_model: int,
|
270 |
+
n_layer: int,
|
271 |
+
vocab_size: int,
|
272 |
+
ssm_cfg=None,
|
273 |
+
attn_layer_idx=None,
|
274 |
+
attn_cfg=None,
|
275 |
+
mlp_layer_idx=None,
|
276 |
+
mlp_cfg=None,
|
277 |
+
norm_epsilon: float = 1e-5,
|
278 |
+
rms_norm: bool = False,
|
279 |
+
initializer_cfg=None,
|
280 |
+
residual_in_fp32=False,
|
281 |
+
device=None,
|
282 |
+
dtype=None,
|
283 |
+
) -> None:
|
284 |
+
super().__init__()
|
285 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
286 |
+
self.residual_in_fp32 = residual_in_fp32
|
287 |
+
|
288 |
+
if rms_norm:
|
289 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm as norm_cls_base
|
290 |
+
else:
|
291 |
+
norm_cls_base = nn.LayerNorm
|
292 |
+
norm_cls = partial(norm_cls_base, eps=norm_epsilon, **factory_kwargs)
|
293 |
+
|
294 |
+
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
295 |
+
|
296 |
+
self.layers = nn.ModuleList(
|
297 |
+
[
|
298 |
+
_create_block(
|
299 |
+
d_model,
|
300 |
+
norm_cls=norm_cls,
|
301 |
+
ssm_cfg=ssm_cfg,
|
302 |
+
attn_layer_idx=attn_layer_idx,
|
303 |
+
attn_cfg=attn_cfg,
|
304 |
+
mlp_layer_idx=mlp_layer_idx,
|
305 |
+
mlp_cfg=mlp_cfg,
|
306 |
+
residual_in_fp32=residual_in_fp32,
|
307 |
+
layer_idx=i,
|
308 |
+
**factory_kwargs,
|
309 |
+
)
|
310 |
+
for i in range(n_layer)
|
311 |
+
]
|
312 |
+
)
|
313 |
+
|
314 |
+
self.norm_f = norm_cls(d_model)
|
315 |
+
|
316 |
+
self.apply(
|
317 |
+
partial(
|
318 |
+
_init_weights,
|
319 |
+
n_layer=n_layer,
|
320 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
321 |
+
n_residuals_per_layer=1,
|
322 |
+
)
|
323 |
+
)
|
324 |
+
|
325 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
326 |
+
"""Allocate inference cache for all layers."""
|
327 |
+
return {
|
328 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
329 |
+
for i, layer in enumerate(self.layers)
|
330 |
+
}
|
331 |
+
|
332 |
+
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
|
333 |
+
"""Forward pass through the model."""
|
334 |
+
hidden_states = self.embedding(input_ids)
|
335 |
+
for layer in self.layers:
|
336 |
+
hidden_states = layer(hidden_states, inference_params=inference_params, **mixer_kwargs)
|
337 |
+
hidden_states = self.norm_f(hidden_states.to(dtype=self.norm_f.weight.dtype))
|
338 |
+
return hidden_states
|
339 |
+
|
340 |
+
|
341 |
+
class ReneLMHeadModel(PreTrainedModel, MambaGenerationMixin):
|
342 |
+
"""
|
343 |
+
Rene language model architecture.
|
344 |
+
Based on mamba_ssm.models.mixer_seq_simple.MambaLMHeadModel, with several adaptations.
|
345 |
+
"""
|
346 |
+
|
347 |
+
config_class = ReneConfig
|
348 |
+
base_model_prefix = "backbone"
|
349 |
+
_no_split_modules = ["Block", "Mamba2"]
|
350 |
+
supports_gradient_checkpointing = True
|
351 |
+
_is_stateful = True
|
352 |
+
_tied_weights_keys = ["lm_head.weight"]
|
353 |
+
|
354 |
+
def __init__(
|
355 |
+
self,
|
356 |
+
config: ReneConfig,
|
357 |
+
initializer_cfg=None,
|
358 |
+
device=None,
|
359 |
+
dtype=None,
|
360 |
+
) -> None:
|
361 |
+
super().__init__(config)
|
362 |
+
d_model = config.d_model
|
363 |
+
n_layer = config.n_layer
|
364 |
+
vocab_size = config.vocab_size
|
365 |
+
ssm_cfg = config.ssm_cfg
|
366 |
+
attn_layer_idx = config.attn_layer_idx
|
367 |
+
attn_cfg = config.attn_cfg
|
368 |
+
mlp_layer_idx = config.mlp_layer_idx
|
369 |
+
mlp_cfg = config.mlp_cfg
|
370 |
+
rms_norm = config.rms_norm
|
371 |
+
residual_in_fp32 = config.residual_in_fp32
|
372 |
+
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
373 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
374 |
+
|
375 |
+
if set(attn_layer_idx).intersection(mlp_layer_idx):
|
376 |
+
raise ValueError(f"Conflicting {attn_layer_idx=} and {mlp_layer_idx=}")
|
377 |
+
|
378 |
+
if vocab_size % pad_vocab_size_multiple != 0:
|
379 |
+
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
380 |
+
|
381 |
+
self.backbone = MixerModel(
|
382 |
+
d_model=d_model,
|
383 |
+
n_layer=n_layer,
|
384 |
+
vocab_size=vocab_size,
|
385 |
+
ssm_cfg=ssm_cfg,
|
386 |
+
attn_layer_idx=attn_layer_idx,
|
387 |
+
attn_cfg=attn_cfg,
|
388 |
+
mlp_layer_idx=mlp_layer_idx,
|
389 |
+
mlp_cfg=mlp_cfg,
|
390 |
+
rms_norm=rms_norm,
|
391 |
+
initializer_cfg=initializer_cfg,
|
392 |
+
residual_in_fp32=residual_in_fp32,
|
393 |
+
**factory_kwargs,
|
394 |
+
)
|
395 |
+
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
396 |
+
|
397 |
+
# Initialize weights
|
398 |
+
self.apply(
|
399 |
+
partial(
|
400 |
+
_init_weights,
|
401 |
+
n_layer=n_layer,
|
402 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
403 |
+
)
|
404 |
+
)
|
405 |
+
self.tie_weights()
|
406 |
+
|
407 |
+
def tie_weights(self):
|
408 |
+
"""Tie embeddings and softmax layer weights if specified by config."""
|
409 |
+
if self.config.tie_word_embeddings:
|
410 |
+
self.lm_head.weight = self.backbone.embedding.weight
|
411 |
+
|
412 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
413 |
+
"""Allocate inference cache."""
|
414 |
+
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
415 |
+
|
416 |
+
def forward(
|
417 |
+
self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs
|
418 |
+
):
|
419 |
+
"""
|
420 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
421 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens.
|
422 |
+
"""
|
423 |
+
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
|
424 |
+
if num_last_tokens > 0:
|
425 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
426 |
+
lm_logits = self.lm_head(hidden_states)
|
427 |
+
|
428 |
+
return CausalLMOutput(logits=lm_logits)
|
429 |
+
|
430 |
+
def generate(self, *args, **kwargs):
|
431 |
+
"""
|
432 |
+
Calls the custom `generate` method from `mamba_ssm.utils.generation.GenerationMixin`.
|
433 |
+
Refer to that method for argument names and defaults.
|
434 |
+
"""
|
435 |
+
return MambaGenerationMixin.generate(self, *args, **kwargs)
|