YikangS committed
Commit
6c9f683
1 Parent(s): 445846c

clean code

Files changed (2)
  1. configuration_jetmoe.py +0 -269
  2. modeling_jetmoe.py +0 -1399
configuration_jetmoe.py DELETED
@@ -1,269 +0,0 @@
1
- """ JetMoE model configuration"""
2
- from collections import OrderedDict
3
- from typing import Any, List, Mapping, Optional
4
-
5
- from transformers import PreTrainedTokenizer, TensorType, is_torch_available
6
- from transformers.configuration_utils import PretrainedConfig
7
- from transformers.onnx import OnnxConfigWithPast, PatchingSpec
8
- from transformers.utils import logging
9
- import torch.nn.init as init
10
- import json
11
-
12
- logger = logging.get_logger(__name__)
13
-
14
-
15
- class JetMoEConfig(PretrainedConfig):
16
- r"""
17
- This is the configuration class to store the configuration of a [`JetMoEModel`]. It is used to instantiate a
18
- JetMoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
19
- with the defaults will yield a similar configuration to that of the JetMoE
20
- [jetmoe-small](https://huggingface.co/jetmoe-small) architecture. Configuration objects
21
- inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from
22
- [`PretrainedConfig`] for more information.
23
-
24
- Args:
25
- vocab_size (`int`, *optional*, defaults to 50400):
26
- Vocabulary size of the JetMoE model. Defines the number of different tokens that can be represented by the
27
- `inputs_ids` passed when calling [`JetMoEModel`].
28
- n_positions (`int`, *optional*, defaults to 2048):
29
- The maximum sequence length that this model might ever be used with. Typically set this to something large
30
- just in case (e.g., 512 or 1024 or 2048).
31
- n_embd (`int`, *optional*, defaults to 4096):
32
- Dimensionality of the embeddings and hidden states.
33
- n_layer (`int`, *optional*, defaults to 28):
34
- Number of hidden layers in the Transformer encoder.
35
- n_head (`int`, *optional*, defaults to 16):
36
- Number of attention heads for each attention layer in the Transformer encoder.
37
- rotary_dim (`int`, *optional*, defaults to 64):
38
- Number of dimensions in the embedding that Rotary Position Embedding is applied to.
39
- n_inner (`int`, *optional*, defaults to None):
40
- Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
41
- activation_function (`str`, *optional*, defaults to `"gelu_new"`):
42
- Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
43
- resid_pdrop (`float`, *optional*, defaults to 0.1):
44
- The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
45
- embd_pdrop (`int`, *optional*, defaults to 0.1):
46
- The dropout ratio for the embeddings.
47
- attn_pdrop (`float`, *optional*, defaults to 0.1):
48
- The dropout ratio for the attention.
49
- layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
50
- The epsilon to use in the layer normalization layers.
51
- initializer_range (`float`, *optional*, defaults to 0.02):
52
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
53
- use_cache (`bool`, *optional*, defaults to `True`):
54
- Whether or not the model should return the last key/values attentions (not used by all models).
55
-
56
- Example:
57
-
58
- ```python
59
- >>> from transformers import JetMoEConfig, JetMoEModel
60
-
61
- >>> # Initializing a JetMoE configuration
62
- >>> configuration = JetMoEConfig()
63
-
64
- >>> # Initializing a model (with random weights) from the configuration
65
- >>> model = JetMoEModel(configuration)
66
-
67
- >>> # Accessing the model configuration
68
- >>> configuration = model.config
69
- ```"""
70
- model_type = "jetmoe"
71
- attribute_map = {
72
- "max_position_embeddings": "n_positions",
73
- "hidden_size": "n_embd",
74
- "num_attention_heads": "n_head",
75
- "num_hidden_layers": "num_layers",
76
- }
77
-
78
- def __init__(
79
- self,
80
- vocab_size=50295,
81
- hidden_size=1024,
82
- num_layers=24,
83
- num_attention_heads=16,
84
- kv_channels = 128,
85
- ffn_hidden_size=2048,
86
- max_position_embeddings=4096,
87
- rotary_percent=1.0,
88
- activation_function="silu",
89
- glu=True,
90
- moe_num_experts=8,
91
- moe_top_k=2,
92
- use_cache=True,
93
- bos_token_id=1,
94
- eos_token_id=2,
95
- tie_word_embeddings=True,
96
- bias=True,
97
- rope_theta=10000.0,
98
- rms_norm_eps=1e-6,
99
- initializer_range=0.01,
100
- **kwargs,
101
- ):
102
- self.vocab_size = vocab_size
103
- self.hidden_size = hidden_size
104
- self.num_layers = num_layers
105
- self.num_attention_heads = num_attention_heads
106
- self.kv_channels = kv_channels
107
- self.ffn_hidden_size = ffn_hidden_size
108
- self.max_position_embeddings = max_position_embeddings
109
- self.rotary_percent = rotary_percent
110
- self.activation_function = activation_function
111
- self.glu = glu
112
- self.moe_num_experts = moe_num_experts
113
- self.moe_top_k = moe_top_k
114
- self.use_cache = use_cache
115
- self.initializer_range = initializer_range
116
-
117
- self.bos_token_id = bos_token_id
118
- self.eos_token_id = eos_token_id
119
-
120
- self.init_method = init.xavier_uniform_
121
- self.output_layer_init_method = init.xavier_uniform_
122
- self.bias = bias
123
- self.rope_theta = rope_theta
124
- self.rms_norm_eps = rms_norm_eps
125
-
126
- super().__init__(
127
- bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
128
- )
129
-
130
- def to_dict(self):
131
- """Returns a dictionary representation of the config, excluding non-serializable attributes."""
132
- return {k: v for k, v in self.__dict__.items() if k not in ['init_method', 'output_layer_init_method', 'torch_dtype', '_pre_quantization_dtype', 'quantization_config']}
133
-
134
- def to_json_string(self, use_diff=False):
135
- """Serializes this instance to a JSON string, excluding non-serializable attributes.
136
-
137
- Args:
138
- use_diff (bool): Whether to use differences with the default config. This argument is
139
- accepted for compatibility with the transformers library but is not
140
- used in this custom implementation.
141
- """
142
- config_dict = self.to_dict()  # to_dict (defined above) already excludes non-serializable attributes
143
- return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
144
-
145
- class JetMoEOnnxConfig(OnnxConfigWithPast):
146
- def __init__(
147
- self,
148
- config: PretrainedConfig,
149
- task: str = "default",
150
- patching_specs: List[PatchingSpec] = None,
151
- use_past: bool = False,
152
- ):
153
- """
154
- Initialize the JetMoEOnnxConfig.
155
-
156
- Args:
157
- config (PretrainedConfig): Pretrained model configuration.
158
- task (str): Task description.
159
- patching_specs (List[PatchingSpec]): List of patching specifications.
160
- use_past (bool): Whether to use past tokens in the configuration.
161
- """
162
- super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
163
- if not getattr(self._config, "pad_token_id", None):
164
- # TODO: how to do that better?
165
- self._config.pad_token_id = 0
166
-
167
- @property
168
- def inputs(self) -> Mapping[str, Mapping[int, str]]:
169
- """
170
- Define the input mappings.
171
-
172
- Returns:
173
- Mapping[str, Mapping[int, str]]: Input mappings.
174
- """
175
- common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
176
- if self.use_past:
177
- self.fill_with_past_key_values_(common_inputs, direction="inputs")
178
- common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
179
- else:
180
- common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
181
-
182
- return common_inputs
183
-
184
- @property
185
- def num_layers(self) -> int:
186
- """
187
- Get the number of layers.
188
-
189
- Returns:
190
- int: Number of layers.
191
- """
192
- return self._config.n_layer
193
-
194
- @property
195
- def num_attention_heads(self) -> int:
196
- """
197
- Get the number of attention heads.
198
-
199
- Returns:
200
- int: Number of attention heads.
201
- """
202
- return self._config.n_head
203
-
204
- def generate_dummy_inputs(
205
- self,
206
- tokenizer: PreTrainedTokenizer,
207
- batch_size: int = -1,
208
- seq_length: int = -1,
209
- is_pair: bool = False,
210
- framework: Optional[TensorType] = None,
211
- ) -> Mapping[str, Any]:
212
- """
213
- Generate dummy inputs for testing.
214
-
215
- Args:
216
- tokenizer (PreTrainedTokenizer): Pretrained tokenizer.
217
- batch_size (int): Batch size.
218
- seq_length (int): Sequence length.
219
- is_pair (bool): Whether the input is a pair.
220
- framework (Optional[TensorType]): Tensor framework.
221
-
222
- Returns:
223
- Mapping[str, Any]: Dummy inputs.
224
- """
225
- common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
226
- tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
227
- )
228
-
229
- # We need to order the inputs in the order they appear in the forward()
230
- ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
231
-
232
- # Need to add the past_keys
233
- if self.use_past:
234
- if not is_torch_available():
235
- raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
236
- else:
237
- import torch
238
-
239
- batch, seqlen = common_inputs["input_ids"].shape
240
- # Not using the same length for past_key_values
241
- past_key_values_length = seqlen + 2
242
- past_shape = (
243
- batch,
244
- self.num_attention_heads,
245
- past_key_values_length,
246
- self._config.hidden_size // self.num_attention_heads,
247
- )
248
- ordered_inputs["past_key_values"] = [
249
- (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
250
- ]
251
-
252
- ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
253
- if self.use_past:
254
- mask_dtype = ordered_inputs["attention_mask"].dtype
255
- ordered_inputs["attention_mask"] = torch.cat(
256
- [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
257
- )
258
-
259
- return ordered_inputs
260
-
261
- @property
262
- def default_onnx_opset(self) -> int:
263
- """
264
- Get the default ONNX opset version.
265
-
266
- Returns:
267
- int: Default ONNX opset version.
268
- """
269
- return 13
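
The `to_dict`/`to_json_string` overrides above exist because the config stores callables (`init.xavier_uniform_`) in `init_method` and `output_layer_init_method`, and `json.dumps` cannot serialize function objects. A minimal, self-contained sketch of that pattern (the class and attribute values below are illustrative, not taken from the repository):

```python
import json


class TinyConfig:
    """Illustrative stand-in for a config that mixes plain values with callables."""

    NON_SERIALIZABLE = {"init_method", "output_layer_init_method"}

    def __init__(self):
        self.hidden_size = 1024
        self.moe_num_experts = 8
        # A callable attribute: json.dumps() would raise TypeError on this.
        self.init_method = lambda tensor: tensor

    def to_dict(self):
        # Keep only JSON-friendly attributes, mirroring the deleted override.
        return {k: v for k, v in self.__dict__.items() if k not in self.NON_SERIALIZABLE}

    def to_json_string(self):
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


print(TinyConfig().to_json_string())
```

The callable is silently dropped from the dump, which is the same behaviour the removed `JetMoEConfig.to_dict` relies on when excluding `init_method`, `output_layer_init_method`, and the quantization-related fields.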
modeling_jetmoe.py DELETED
@@ -1,1399 +0,0 @@
1
- """ PyTorch JetMoE model."""
2
-
3
- from typing import List, Optional, Tuple, Union
4
- import warnings, math
5
-
6
- import torch
7
- import torch.utils.checkpoint
8
- from torch import nn
9
- from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
10
- from torch.nn import functional as F
11
-
12
- import megablocks
13
- from transformers.modeling_outputs import (
14
- BaseModelOutputWithPast,
15
- CausalLMOutputWithPast,
16
- SequenceClassifierOutputWithPast,
17
- dataclass
18
- )
19
- from transformers.modeling_utils import PreTrainedModel
20
- from transformers.utils import (
21
- add_start_docstrings,
22
- add_start_docstrings_to_model_forward,
23
- is_flash_attn_2_available,
24
- is_flash_attn_greater_or_equal_2_10,
25
- replace_return_docstrings,
26
- logging
27
- )
28
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
29
- from transformers.cache_utils import Cache, DynamicCache
30
- from .configuration_jetmoe import JetMoEConfig
31
- from jetmoe_model.utils import moe
32
-
33
- if is_flash_attn_2_available():
34
- from flash_attn import flash_attn_func, flash_attn_varlen_func
35
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
36
-
37
- logger = logging.get_logger(__name__)
38
-
39
- _CHECKPOINT_FOR_DOC = "jetmoe"
40
- _CONFIG_FOR_DOC = "JetMoEConfig"
41
-
42
-
43
- @dataclass
44
- class JetMoEBaseModelOutputWithPast(BaseModelOutputWithPast):
45
- """
46
- Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
47
-
48
- Args:
49
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
50
- Sequence of hidden-states at the output of the last layer of the model.
51
-
52
- If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
53
- hidden_size)` is output.
54
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
55
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
56
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
57
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
58
- encoder_sequence_length, embed_size_per_head)`.
59
-
60
- Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
61
- `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
62
- input) to speed up sequential decoding.
63
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
64
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
65
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
66
-
67
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
68
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
69
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
70
- sequence_length)`.
71
-
72
- Attention weights after the attention softmax, used to compute the weighted average in the self-attention
73
- heads.
74
- """
75
-
76
- last_hidden_state: torch.FloatTensor = None
77
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
78
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
79
- attentions: Optional[Tuple[torch.FloatTensor]] = None
80
- aux_loss: Optional[torch.FloatTensor] = None
81
-
82
-
83
- @dataclass
84
- class JetMoECausalLMOutputWithPast(CausalLMOutputWithPast):
85
- """
86
- Base class for causal language model (or autoregressive) outputs.
87
-
88
- Args:
89
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
90
- Language modeling loss (for next-token prediction).
91
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
92
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
93
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
94
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
95
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
96
-
97
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
98
- `past_key_values` input) to speed up sequential decoding.
99
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
100
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
101
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
102
-
103
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
104
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
105
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
106
- sequence_length)`.
107
-
108
- Attention weights after the attention softmax, used to compute the weighted average in the self-attention
109
- heads.
110
- """
111
-
112
- loss: Optional[torch.FloatTensor] = None
113
- logits: torch.FloatTensor = None
114
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
115
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
116
- attentions: Optional[Tuple[torch.FloatTensor]] = None
117
- aux_loss: Optional[torch.FloatTensor] = None
118
-
119
-
120
- @dataclass
121
- class JetMoESequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast):
122
- """
123
- Base class for outputs of sentence classification models.
124
-
125
- Args:
126
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
127
- Classification (or regression if config.num_labels==1) loss.
128
- logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
129
- Classification (or regression if config.num_labels==1) scores (before SoftMax).
130
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
131
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
132
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
133
-
134
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
135
- `past_key_values` input) to speed up sequential decoding.
136
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
137
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
138
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
139
-
140
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
141
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
142
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
143
- sequence_length)`.
144
-
145
- Attention weights after the attention softmax, used to compute the weighted average in the self-attention
146
- heads.
147
- """
148
-
149
- loss: Optional[torch.FloatTensor] = None
150
- logits: torch.FloatTensor = None
151
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
152
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
153
- attentions: Optional[Tuple[torch.FloatTensor]] = None
154
- aux_loss: Optional[torch.FloatTensor] = None
155
-
156
-
157
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
158
- def _get_unpad_data(attention_mask):
159
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
160
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
161
- max_seqlen_in_batch = seqlens_in_batch.max().item()
162
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
163
- return (
164
- indices,
165
- cu_seqlens,
166
- max_seqlen_in_batch,
167
- )
168
-
169
- class JetMoERMSNorm(nn.Module):
170
- def __init__(self, hidden_size, eps=1e-6):
171
- """
172
- JetMoERMSNorm module
173
- """
174
- super().__init__()
175
- self.weight = nn.Parameter(torch.ones(hidden_size))
176
- self.variance_epsilon = eps
177
-
178
- def forward(self, hidden_states):
179
- input_dtype = hidden_states.dtype
180
- hidden_states = hidden_states.to(torch.float32)
181
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
182
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
183
- return self.weight * hidden_states.to(input_dtype)
184
-
185
-
186
- # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
187
- class JetMoERotaryEmbedding(nn.Module):
188
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
189
- super().__init__()
190
-
191
- self.dim = dim
192
- self.max_position_embeddings = max_position_embeddings
193
- self.base = base
194
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
195
- self.register_buffer("inv_freq", inv_freq, persistent=False)
196
-
197
- # Build here to make `torch.jit.trace` work.
198
- self._set_cos_sin_cache(
199
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
200
- )
201
-
202
- def _set_cos_sin_cache(self, seq_len, device, dtype):
203
- self.max_seq_len_cached = seq_len
204
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
205
-
206
- freqs = torch.outer(t, self.inv_freq)
207
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
208
- emb = torch.cat((freqs, freqs), dim=-1)
209
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
210
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
211
-
212
- def forward(self, x, seq_len=None):
213
- # x: [bs, num_attention_heads, seq_len, head_size]
214
- if seq_len > self.max_seq_len_cached:
215
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
216
-
217
- return (
218
- self.cos_cached[:seq_len].to(dtype=x.dtype),
219
- self.sin_cached[:seq_len].to(dtype=x.dtype),
220
- )
221
-
222
-
223
- # Copied from transformers.models.llama.modeling_llama.rotate_half
224
- def rotate_half(x):
225
- """Rotates half the hidden dims of the input."""
226
- x1 = x[..., : x.shape[-1] // 2]
227
- x2 = x[..., x.shape[-1] // 2 :]
228
- return torch.cat((-x2, x1), dim=-1)
229
-
230
-
231
- # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
232
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2):
233
- """Applies Rotary Position Embedding to the query and key tensors.
234
-
235
- Args:
236
- q (`torch.Tensor`): The query tensor.
237
- k (`torch.Tensor`): The key tensor.
238
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
239
- sin (`torch.Tensor`): The sine part of the rotary embedding.
240
- position_ids (`torch.Tensor`):
241
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
242
- used to pass offsetted position ids when working with a KV-cache.
243
- unsqueeze_dim (`int`, *optional*, defaults to 2):
244
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
245
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
246
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
247
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
248
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
249
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
250
- Returns:
251
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
252
- """
253
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
254
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
255
- q_embed = (q * cos) + (rotate_half(q) * sin)
256
- k_embed = (k * cos) + (rotate_half(k) * sin)
257
- return q_embed, k_embed
258
-
259
-
260
- class JetMoEAttention(nn.Module):
261
- """
262
- Multi-headed attention from 'Attention Is All You Need' paper.
263
- """
264
-
265
- def __init__(self, config: JetMoEConfig, layer_idx: Optional[int] = None):
266
- """
267
- Initialize the JetMoEAttention module.
268
-
269
- Args:
270
- config: Configuration object with model hyperparameters.
271
- """
272
- super().__init__()
273
- self.config = config
274
- self.layer_idx = layer_idx
275
- self.is_causal = True
276
- if layer_idx is None:
277
- logger.warning_once(
278
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
279
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
280
- "when creating this class."
281
- )
282
-
283
- self.top_k = config.moe_top_k
284
-
285
- self.kv_projection_size = config.kv_channels * config.num_attention_heads
286
- self.num_key_value_heads = config.num_attention_heads
287
- self.num_heads = self.num_key_value_heads * self.top_k
288
- self.hidden_size_per_attention_head = config.kv_channels
289
-
290
- self.experts = moe.MoE(
291
- input_size=config.hidden_size,
292
- hidden_size=self.kv_projection_size,
293
- num_experts=config.moe_num_experts,
294
- top_k=config.moe_top_k,
295
- glu=False
296
- )
297
-
298
- self.kv_proj = torch.nn.Linear(
299
- config.hidden_size, self.kv_projection_size * 2, bias=False
300
- )
301
-
302
- self.rotary_emb = JetMoERotaryEmbedding(
303
- config.kv_channels,
304
- max_position_embeddings=config.max_position_embeddings,
305
- base=config.rope_theta,
306
- )
307
-
308
- # def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
309
- # return tensor.view(bsz, seq_len, self.num_attention_heads, self.hidden_size_per_attention_head).transpose(1, 2).contiguous()
310
-
311
- def forward(
312
- self,
313
- hidden_states: torch.Tensor,
314
- attention_mask: Optional[torch.Tensor] = None,
315
- position_ids: Optional[torch.LongTensor] = None,
316
- past_key_value: Optional[Cache] = None,
317
- output_attentions: bool = False,
318
- use_cache: bool = False,
319
- **kwargs,
320
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
321
- if "padding_mask" in kwargs:
322
- warnings.warn(
323
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
324
- )
325
- bsz, q_len, _ = hidden_states.size()
326
-
327
- query_states, aux_loss = self.experts.map(hidden_states)
328
- key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
329
-
330
- query_states = query_states.view(bsz, q_len, self.num_heads, self.hidden_size_per_attention_head).transpose(1, 2)
331
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head).transpose(1, 2)
332
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head).transpose(1, 2)
333
-
334
- kv_seq_len = key_states.shape[2]
335
- if past_key_value is not None:
336
- if self.layer_idx is None:
337
- raise ValueError(
338
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
339
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
340
- "with a layer index."
341
- )
342
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
343
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
344
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, unsqueeze_dim=1)
345
-
346
- if past_key_value is not None:
347
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
348
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
349
-
350
- # repeat k/v heads if n_kv_heads < n_heads
351
- key_states = key_states.repeat(1, self.top_k, 1, 1)
352
- value_states = value_states.repeat(1, self.top_k, 1, 1)
353
-
354
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.hidden_size_per_attention_head)
355
-
356
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
357
- raise ValueError(
358
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
359
- f" {attn_weights.size()}"
360
- )
361
-
362
- if attention_mask is not None:
363
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
364
- raise ValueError(
365
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
366
- )
367
-
368
- attn_weights = attn_weights + attention_mask
369
-
370
- # upcast attention to fp32
371
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
372
- # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
373
- attn_output = torch.matmul(attn_weights, value_states)
374
-
375
- if attn_output.size() != (bsz, self.num_heads, q_len, self.hidden_size_per_attention_head):
376
- raise ValueError(
377
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.hidden_size_per_attention_head)}, but is"
378
- f" {attn_output.size()}"
379
- )
380
-
381
- attn_output = attn_output.transpose(1, 2).contiguous()
382
- attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)
383
-
384
- attn_output = self.experts.reduce(attn_output)
385
- attn_output = attn_output.view(bsz, q_len, -1)
386
-
387
- if not output_attentions:
388
- attn_weights = None
389
-
390
- return attn_output, attn_weights, past_key_value, aux_loss
391
-
392
-
393
- # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->JetMoE
394
- class JetMoESdpaAttention(JetMoEAttention):
395
- """
396
- JetMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
397
- `JetMoEAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
398
- SDPA API.
399
- """
400
-
401
- # Adapted from JetMoEAttention.forward
402
- def forward(
403
- self,
404
- hidden_states: torch.Tensor,
405
- attention_mask: Optional[torch.Tensor] = None,
406
- position_ids: Optional[torch.LongTensor] = None,
407
- past_key_value: Optional[Cache] = None,
408
- output_attentions: bool = False,
409
- use_cache: bool = False,
410
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
411
- if output_attentions:
412
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
413
- logger.warning_once(
414
- "JetMoEModel is using JetMoESdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
415
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
416
- )
417
- return super().forward(
418
- hidden_states=hidden_states,
419
- attention_mask=attention_mask,
420
- position_ids=position_ids,
421
- past_key_value=past_key_value,
422
- output_attentions=output_attentions,
423
- use_cache=use_cache,
424
- )
425
-
426
- bsz, q_len, _ = hidden_states.size()
427
-
428
- query_states, aux_loss = self.experts.map(hidden_states)
429
- key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
430
-
431
- query_states = query_states.view(bsz, q_len, self.num_heads, self.hidden_size_per_attention_head).transpose(1, 2)
432
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head).transpose(1, 2)
433
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.hidden_size_per_attention_head).transpose(1, 2)
434
-
435
- kv_seq_len = key_states.shape[2]
436
- if past_key_value is not None:
437
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
438
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
439
-
440
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids, unsqueeze_dim=1)
441
-
442
- if past_key_value is not None:
443
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
444
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
445
-
446
- key_states = key_states.repeat(1, self.top_k, 1, 1)
447
- value_states = value_states.repeat(1, self.top_k, 1, 1)
448
-
449
- if attention_mask is not None:
450
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
451
- raise ValueError(
452
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
453
- )
454
-
455
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
456
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
457
- if query_states.device.type == "cuda" and attention_mask is not None:
458
- query_states = query_states.contiguous()
459
- key_states = key_states.contiguous()
460
- value_states = value_states.contiguous()
461
-
462
- attn_output = torch.nn.functional.scaled_dot_product_attention(
463
- query_states,
464
- key_states,
465
- value_states,
466
- attn_mask=attention_mask,
467
- dropout_p=0.0,
468
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
469
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
470
- )
471
-
472
- attn_output = attn_output.transpose(1, 2).contiguous()
473
- attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)
474
-
475
- attn_output = self.experts.reduce(attn_output)
476
- attn_output = attn_output.view(bsz, q_len, -1)
477
-
478
- return attn_output, None, past_key_value, aux_loss
479
-
480
-
481
- class JetMoEFlashAttention2(JetMoEAttention):
482
- def __init__(self, *args, **kwargs):
483
- super().__init__(*args, **kwargs)
484
-
485
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
486
- # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
487
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
488
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
489
-
490
- def forward(
491
- self,
492
- hidden_states: Optional[torch.FloatTensor],
493
- attention_mask: Optional[torch.FloatTensor] = None,
494
- position_ids: Optional[torch.LongTensor] = None,
495
- past_key_value: Optional[Cache] = None,
496
- use_cache: Optional[bool] = False,
497
- output_attentions: Optional[bool] = False,
498
- **kwargs,
499
- ) -> Union[
500
- Tuple[torch.Tensor, Tuple[torch.Tensor]],
501
- Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
502
- ]:
503
- """
504
- Forward pass of the JetMoEAttention module.
505
-
506
- Args:
507
- hidden_states (Optional[torch.FloatTensor]): Input hidden states.
508
- attention_mask (Optional[torch.FloatTensor]): Attention mask.
509
- layer_past (Optional[Tuple[torch.Tensor]]): Past layer state.
510
- use_cache (Optional[bool]): Whether to use cached states.
511
- output_attentions (Optional[bool]): Whether to output attention weights.
512
-
513
- Returns:
514
- Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[...]]]: Tuple containing outputs.
515
- """
516
- #assert attention_mask is None, "attention_mask is not supported"
517
- assert output_attentions is False, "output_attentions is not supported"
518
-
519
- B, T, C = hidden_states.size() # batch size, sequence length, embedding dimensionality (hidden_size)
520
-
521
- # calculate query, key, values
522
- query_layer, aux_loss = self.experts.map(hidden_states)
523
- key_layer, value_layer = self.kv_proj(hidden_states).chunk(2, dim=-1)
524
-
525
- query_layer = query_layer.view(B, T, self.num_heads, self.hidden_size_per_attention_head) # (B, T, k * nh, hs)
526
- key_layer = key_layer.view(B, T, self.num_key_value_heads, self.hidden_size_per_attention_head) # (B, T, nh, hs)
527
- value_layer = value_layer.view(B, T, self.num_key_value_heads, self.hidden_size_per_attention_head) # (B, T, nh, hs)
528
-
529
- kv_seq_len = key_layer.shape[1]
530
- if past_key_value is not None:
531
- if self.layer_idx is None:
532
- raise ValueError(
533
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
534
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
535
- "with a layer index."
536
- )
537
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
538
- cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
539
- query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
540
-
541
- # query_layer = query_layer.contiguous()
542
- # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
543
- key_layer = key_layer.repeat(1, 1, self.top_k, 1)
544
- value_layer = value_layer.repeat(1, 1, self.top_k, 1)
545
-
546
- if past_key_value is not None:
547
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
548
- # print(self.layer_idx, key_layer.size())
549
- key_layer = key_layer.transpose(1, 2)
550
- value_layer = value_layer.transpose(1, 2)
551
- key_layer, value_layer = past_key_value.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
552
- key_layer = key_layer.transpose(1, 2)
553
- value_layer = value_layer.transpose(1, 2)
554
-
555
- context_layer = self._flash_attention_forward(
556
- query_layer,
557
- key_layer,
558
- value_layer,
559
- attention_mask,
560
- T,
561
- )
562
-
563
- # output projection
564
- y = self.experts.reduce(context_layer.reshape(T, B, self.top_k, self.kv_projection_size))
565
- y = y.view(B, T, C) # re-assemble all head outputs side by side
566
-
567
- if not output_attentions:
568
- attn_weights = None
569
-
570
- return y, attn_weights, past_key_value, aux_loss
571
-
572
- def _flash_attention_forward(
573
- self,
574
- query_states,
575
- key_states,
576
- value_states,
577
- attention_mask,
578
- query_length,
579
- dropout=0.0,
580
- softmax_scale=None,
581
- ):
582
- """
583
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
584
- first unpad the input, then computes the attention scores and pad the final attention scores.
585
-
586
- Args:
587
- query_states (`torch.Tensor`):
588
- Input query states to be passed to Flash Attention API
589
- key_states (`torch.Tensor`):
590
- Input key states to be passed to Flash Attention API
591
- value_states (`torch.Tensor`):
592
- Input value states to be passed to Flash Attention API
593
- attention_mask (`torch.Tensor`):
594
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
595
- position of padding tokens and 1 for the position of non-padding tokens.
596
- dropout (`float`):
597
- Attention dropout
598
- softmax_scale (`float`, *optional*):
599
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
600
- """
601
- if not self._flash_attn_uses_top_left_mask:
602
- causal = self.is_causal
603
- else:
604
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
605
- causal = self.is_causal and query_length != 1
606
-
607
- # Contains at least one padding token in the sequence
608
- if attention_mask is not None:
609
- batch_size = query_states.shape[0]
610
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
611
- query_states, key_states, value_states, attention_mask, query_length
612
- )
613
-
614
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
615
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
616
-
617
- attn_output_unpad = flash_attn_varlen_func(
618
- query_states,
619
- key_states,
620
- value_states,
621
- cu_seqlens_q=cu_seqlens_q,
622
- cu_seqlens_k=cu_seqlens_k,
623
- max_seqlen_q=max_seqlen_in_batch_q,
624
- max_seqlen_k=max_seqlen_in_batch_k,
625
- dropout_p=dropout,
626
- softmax_scale=softmax_scale,
627
- causal=causal,
628
- )
629
-
630
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
631
- else:
632
- attn_output = flash_attn_func(
633
- query_states,
634
- key_states,
635
- value_states,
636
- dropout,
637
- softmax_scale=softmax_scale,
638
- causal=causal
639
- )
640
-
641
- return attn_output
642
-
643
-
644
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
645
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
646
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
647
-
648
- key_layer = index_first_axis(
649
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
650
- )
651
- value_layer = index_first_axis(
652
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
653
- )
654
- if query_length == kv_seq_len:
655
- query_layer = index_first_axis(
656
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
657
- )
658
- cu_seqlens_q = cu_seqlens_k
659
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
660
- indices_q = indices_k
661
- elif query_length == 1:
662
- max_seqlen_in_batch_q = 1
663
- cu_seqlens_q = torch.arange(
664
- batch_size + 1, dtype=torch.int32, device=query_layer.device
665
- ) # There is a memcpy here, that is very bad.
666
- indices_q = cu_seqlens_q[:-1]
667
- query_layer = query_layer.squeeze(1)
668
- else:
669
- # The -q_len: slice assumes left padding.
670
- attention_mask = attention_mask[:, -query_length:]
671
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
672
-
673
- return (
674
- query_layer,
675
- key_layer,
676
- value_layer,
677
- indices_q,
678
- (cu_seqlens_q, cu_seqlens_k),
679
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
680
- )
681
-
682
-
683
- JETMOE_ATTENTION_CLASSES = {
684
- "eager": JetMoEAttention,
685
- "flash_attention_2": JetMoEFlashAttention2,
686
- "sdpa": JetMoESdpaAttention,
687
- }
688
-
689
-
690
- class JetMoEBlock(nn.Module):
691
- def __init__(self, config: JetMoEConfig, layer_idx: Optional[int] = None):
692
- """
693
- Initialize the JetMoEBlock module.
694
-
695
- Args:
696
- config: Configuration object with model hyperparameters.
697
- """
698
- super().__init__()
699
- self.input_layernorm = JetMoERMSNorm(config.hidden_size)
700
- #self.self_attention = JetMoEAttention(config, layer_idx)
701
- self.self_attention = JETMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
702
- self.post_attention_layernorm = JetMoERMSNorm(config.hidden_size)
703
-
704
- moe_args = megablocks.layers.arguments.from_megatron(config)
705
- moe_args.activation_fn = F.silu
706
- moe_args.return_bias = False
707
- # self.mlp = megablocks.layers.dmoe.dMoE(moe_args)
708
- self.mlp = moe.MoE(
709
- input_size=config.hidden_size,
710
- hidden_size=config.ffn_hidden_size,
711
- num_experts=config.moe_num_experts,
712
- activation=F.silu,
713
- top_k=config.moe_top_k,
714
- glu=config.glu
715
- )
716
-
717
- def forward(
718
- self,
719
- hidden_states: Optional[torch.FloatTensor],
720
- position_ids: Optional[torch.LongTensor] = None,
721
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
722
- attention_mask: Optional[torch.FloatTensor] = None,
723
- output_attentions: Optional[bool] = False,
724
- use_cache: Optional[bool] = False,
725
- **kwargs,
726
- ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
727
- """
728
- Forward pass of the JetMoEBlock module.
729
-
730
- Args:
731
- hidden_states (Optional[torch.FloatTensor]): Input hidden states.
732
- layer_past (Optional[Tuple[torch.Tensor]]): Past layer state.
733
- attention_mask (Optional[torch.FloatTensor]): Attention mask.
734
- head_mask (Optional[torch.FloatTensor]): Head mask.
735
- use_cache (Optional[bool]): Whether to use cached states.
736
- output_attentions (Optional[bool]): Whether to output attention weights.
737
-
738
- Returns:
739
- Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
740
- Tuple containing outputs or optional attention weights.
741
- """
742
- # Self Attention
743
- attn_output, self_attn_weights, present_key_value, att_aux_loss = self.self_attention(
744
- hidden_states=self.input_layernorm(hidden_states),
745
- attention_mask=attention_mask,
746
- position_ids=position_ids,
747
- past_key_value=past_key_value,
748
- output_attentions=output_attentions,
749
- use_cache=use_cache,
750
- )
751
-
752
- hidden_states = hidden_states + attn_output
753
- x_mlp, mlp_aux_loss = self.mlp(self.post_attention_layernorm(hidden_states))
754
- hidden_states = hidden_states + x_mlp
755
-
756
- outputs = (hidden_states,)
757
-
758
- if output_attentions:
759
- outputs += (self_attn_weights,)
760
-
761
- if use_cache:
762
- outputs += (present_key_value,)
763
-
764
- outputs += (att_aux_loss + mlp_aux_loss,)
765
-
766
- return outputs
767
-
768
-
769
-
770
- class JetMoEPreTrainedModel(PreTrainedModel):
771
- """
772
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
773
- models.
774
- """
775
-
776
- config_class = JetMoEConfig
777
- base_model_prefix = "transformer"
778
- supports_gradient_checkpointing = True
779
- _no_split_modules = ["JetMoEBlock"]
780
- _skip_keys_device_placement = "past_key_values"
781
- _supports_flash_attn_2 = True
782
- _supports_sdpa = True
783
- _supports_cache_class = True
784
-
785
- def __init__(self, *inputs, **kwargs):
786
- """
787
- Initialize the JetMoEPreTrainedModel.
788
-
789
- Args:
790
- *inputs: Variable length input arguments.
791
- **kwargs: Keyword arguments.
792
- """
793
- super().__init__(*inputs, **kwargs)
794
-
795
- self.gradient_checkpointing = False
796
-
797
- def _init_weights(self, module):
798
- """Initialize the weights."""
799
- if isinstance(module, (nn.Linear,)):
800
- # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
801
- # cf https://github.com/pytorch/pytorch/pull/5617
802
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
803
- if module.bias is not None:
804
- module.bias.data.zero_()
805
- elif isinstance(module, nn.Embedding):
806
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
807
- if module.padding_idx is not None:
808
- module.weight.data[module.padding_idx].zero_()
809
- elif isinstance(module, nn.LayerNorm):
810
- module.bias.data.zero_()
811
- module.weight.data.fill_(1.0)
812
-
813
- # def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={}):
814
- # for module in self.modules():
815
- # if hasattr(module, "gradient_checkpointing"):
816
- # self._set_gradient_checkpointing(
817
- # module, True, gradient_checkpointing_kwargs
818
- # )
819
-
820
- # def gradient_checkpointing_disable(self):
821
- # for module in self.modules():
822
- # if hasattr(module, "gradient_checkpointing"):
823
- # self._set_gradient_checkpointing(
824
- # module, False
825
- # )
826
-
827
- # def _set_gradient_checkpointing(
828
- # self,
829
- # module,
830
- # value=False,
831
- # gradient_checkpointing_kwargs={"use_reentrant": False},
832
- # ):
833
- # """
834
- # Set gradient checkpointing for the JetMoEModel.
835
-
836
- # Args:
837
- # module: The module for which gradient checkpointing is set.
838
- # value (bool): Whether to enable gradient checkpointing.
839
- # """
840
- # self._gradient_checkpointing_func = checkpoint
841
- # self.gradient_checkpointing = True
842
- # if isinstance(module, JetMoEModel):
843
- # module.gradient_checkpointing = value
844
- # module.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs
845
- # module._gradient_checkpointing_func = checkpoint
846
-
847
- MODULEFORMER_START_DOCSTRING = r"""
848
- This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
849
- it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
850
- behavior.
851
-
852
- Parameters:
853
- config ([`JetMoEConfig`]): Model configuration class with all the parameters of the model.
854
- Initializing with a config file does not load the weights associated with the model, only the
855
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
856
- """
857
-
858
- MODULEFORMER_INPUTS_DOCSTRING = r"""
859
- Args:
860
- input_ids (`torch.LongTensor` of shape `({0})`):
861
- Indices of input sequence tokens in the vocabulary.
862
-
863
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
864
- [`PreTrainedTokenizer.__call__`] for details.
865
-
866
- [What are input IDs?](../glossary#input-ids)
867
- attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
868
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
869
-
870
- - 1 for tokens that are **not masked**,
871
- - 0 for tokens that are **masked**.
872
-
873
- [What are attention masks?](../glossary#attention-mask)
874
- token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
875
- Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
876
- 1]`:
877
-
878
- - 0 corresponds to a *sentence A* token,
879
- - 1 corresponds to a *sentence B* token.
880
-
881
- [What are token type IDs?](../glossary#token-type-ids)
882
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
883
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
884
- config.n_positions - 1]`.
885
-
886
- [What are position IDs?](../glossary#position-ids)
887
- head_mask (`torch.FloatTensor` of shape `(num_attention_heads,)` or `(n_layer, num_attention_heads)`, *optional*):
888
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
889
-
890
- - 1 indicates the head is **not masked**,
891
- - 0 indicates the head is **masked**.
892
-
893
- inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_dim)`, *optional*):
894
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
895
- is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
896
- model's internal embedding lookup matrix.
897
- output_attentions (`bool`, *optional*):
898
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
899
- tensors for more detail.
900
- output_hidden_states (`bool`, *optional*):
901
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
902
- more detail.
903
- return_dict (`bool`, *optional*):
904
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
905
- """
906
-
907
-
908
- @add_start_docstrings(
909
- "The bare JetMoE Model outputting raw hidden-states without any specific head on top.",
910
- MODULEFORMER_START_DOCSTRING,
911
- )
912
- class JetMoEModel(JetMoEPreTrainedModel):
913
- """
914
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`JetMoEBlock`]
915
-
916
- Args:
917
- config: JetMoEConfig
918
- """
919
-
920
- def __init__(self, config: JetMoEConfig):
921
- super().__init__(config)
922
- self.padding_idx = config.pad_token_id
923
- self.vocab_size = config.vocab_size
924
-
925
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
926
- self.layers = nn.ModuleList(
927
- [JetMoEBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
928
- )
929
- self._attn_implementation = config._attn_implementation
930
- self.norm = JetMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
931
-
932
- self.gradient_checkpointing = False
933
- # Initialize weights and apply final processing
934
- self.post_init()
935
-
936
- def get_input_embeddings(self):
937
- return self.embed_tokens
938
-
939
- def set_input_embeddings(self, value):
940
- self.embed_tokens = value
941
-
942
- @add_start_docstrings_to_model_forward(MODULEFORMER_INPUTS_DOCSTRING)
943
- def forward(
944
- self,
945
- input_ids: torch.LongTensor = None,
946
- attention_mask: Optional[torch.Tensor] = None,
947
- position_ids: Optional[torch.LongTensor] = None,
948
- past_key_values: Optional[List[torch.FloatTensor]] = None,
949
- inputs_embeds: Optional[torch.FloatTensor] = None,
950
- use_cache: Optional[bool] = None,
951
- output_attentions: Optional[bool] = None,
952
- output_hidden_states: Optional[bool] = None,
953
- return_dict: Optional[bool] = None,
954
- ) -> Union[Tuple, BaseModelOutputWithPast]:
955
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
956
- output_hidden_states = (
957
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
958
- )
959
- use_cache = use_cache if use_cache is not None else self.config.use_cache
960
-
961
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
962
-
963
- # retrieve input_ids and inputs_embeds
964
- if input_ids is not None and inputs_embeds is not None:
965
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
966
- elif input_ids is not None:
967
- batch_size, seq_length = input_ids.shape
968
- elif inputs_embeds is not None:
969
- batch_size, seq_length, _ = inputs_embeds.shape
970
- else:
971
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
972
-
973
- if self.gradient_checkpointing and self.training:
974
- if use_cache:
975
- logger.warning_once(
976
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
977
- )
978
- use_cache = False
979
-
980
- past_key_values_length = 0
981
-
982
- if use_cache:
983
- use_legacy_cache = not isinstance(past_key_values, Cache)
984
- if use_legacy_cache:
985
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
986
- past_key_values_length = past_key_values.get_usable_length(seq_length)
987
-
988
- if position_ids is None:
989
- device = input_ids.device if input_ids is not None else inputs_embeds.device
990
- position_ids = torch.arange(
991
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
992
- )
993
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
994
- else:
995
- position_ids = position_ids.view(-1, seq_length).long()
996
-
997
- if inputs_embeds is None:
998
- inputs_embeds = self.embed_tokens(input_ids)
999
-
1000
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1001
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1002
- if is_padding_right:
1003
- raise ValueError(
1004
- "You are attempting to perform batched generation with padding_side='right'"
1005
- " this may lead to unexpected behaviour for Flash Attention version of JetMoE. Make sure to "
1006
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1007
- )
1008
-
1009
- if self._attn_implementation == "flash_attention_2":
1010
- # 2d mask is passed through the layers
1011
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1012
- elif self._attn_implementation == "sdpa" and not output_attentions:
1013
- # output_attentions=True can not be supported when using SDPA, and we fall back on
1014
- # the manual implementation that requires a 4D causal mask in all cases.
1015
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1016
- attention_mask,
1017
- (batch_size, seq_length),
1018
- inputs_embeds,
1019
- past_key_values_length,
1020
- )
1021
- else:
1022
- # 4d mask is passed through the layers
1023
- attention_mask = _prepare_4d_causal_attention_mask(
1024
- attention_mask,
1025
- (batch_size, seq_length),
1026
- inputs_embeds,
1027
- past_key_values_length,
1028
- )
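Editor's note: as a quick illustration of what the 4D mask prepared above looks like, here is a hand-rolled toy version for a batch of one, two cached positions, and three new tokens. It does not call the transformers helper; it only builds the additive causal mask of shape `(batch, 1, query_len, key_len)` that the eager attention path expects, with dtype handling deliberately simplified.

```python
import torch

batch_size, query_len, past_len = 1, 3, 2
key_len = past_len + query_len
min_value = torch.finfo(torch.float32).min

# Each new (query) position may attend to all cached positions and to itself
# plus earlier new positions, but not to future positions.
causal = torch.full((query_len, key_len), min_value)
allowed = torch.arange(key_len) <= (torch.arange(query_len).unsqueeze(-1) + past_len)
causal = causal.masked_fill(allowed, 0.0)

# Broadcast to the 4D shape (batch, 1, query_len, key_len) used by attention.
mask_4d = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, query_len, key_len)
print((mask_4d == 0).int().squeeze())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```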
1029
-
1030
- hidden_states = inputs_embeds
1031
-
1032
- # decoder layers
1033
- all_hidden_states = () if output_hidden_states else None
1034
- all_self_attns = () if output_attentions else None
1035
- next_decoder_cache = None
1036
-
1037
- aux_loss = 0
1038
- for decoder_layer in self.layers:
1039
- if output_hidden_states:
1040
- all_hidden_states += (hidden_states,)
1041
-
1042
- # hidden_states: Optional[torch.FloatTensor],
1043
- # position_ids: Optional[torch.LongTensor] = None,
1044
- # past_key_value: Optional[Tuple[torch.Tensor]] = None,
1045
- # attention_mask: Optional[torch.FloatTensor] = None,
1046
- # output_attentions: Optional[bool] = False,
1047
- # use_cache: Optional[bool] = False,
1048
-
1049
- if self.gradient_checkpointing and self.training:
1050
- layer_outputs = self._gradient_checkpointing_func(
1051
- #decoder_layer.__call__,
1052
- decoder_layer,
1053
- hidden_states,
1054
- position_ids,
1055
- past_key_values,
1056
- attention_mask,
1057
- output_attentions,
1058
- use_cache,
1059
- use_reentrant=False,
1060
- )
1061
- else:
1062
- layer_outputs = decoder_layer(
1063
- hidden_states,
1064
- attention_mask=attention_mask,
1065
- position_ids=position_ids,
1066
- past_key_value=past_key_values,
1067
- output_attentions=output_attentions,
1068
- use_cache=use_cache,
1069
- )
1070
-
1071
- hidden_states = layer_outputs[0]
1072
-
1073
- if use_cache:
1074
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1075
-
1076
- if output_attentions:
1077
- all_self_attns += (layer_outputs[1],)
1078
-
1079
- aux_loss += layer_outputs[-1]
1080
-
1081
- hidden_states = self.norm(hidden_states)
1082
-
1083
- # add hidden states from the last decoder layer
1084
- if output_hidden_states:
1085
- all_hidden_states += (hidden_states,)
1086
-
1087
- next_cache = None
1088
- if use_cache:
1089
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1090
-
1091
- if not return_dict:
1092
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1093
- return JetMoEBaseModelOutputWithPast(
1094
- last_hidden_state=hidden_states,
1095
- past_key_values=next_cache,
1096
- hidden_states=all_hidden_states,
1097
- attentions=all_self_attns,
1098
- aux_loss=aux_loss,
1099
- )
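Editor's note: the indexing into `layer_outputs` in the loop above is easy to misread, so here is a small stand-in that mimics the tuple layout the loop assumes: each layer returns `(hidden_states, [attn_weights,] present_key_value, ..., aux_loss)`, which is why the cache is read from position `2 if output_attentions else 1` and the auxiliary load-balancing loss from position `-1`. The fake layer below is purely illustrative and not part of the model code.

```python
import torch

def fake_decoder_layer(hidden_states, output_attentions=False, use_cache=True):
    """Stand-in returning the same tuple layout the loop above indexes into."""
    attn_weights = torch.rand(1, 4, hidden_states.shape[1], hidden_states.shape[1])
    present_key_value = ("k", "v")      # placeholder cache entry
    aux_loss = torch.tensor(0.02)       # per-layer MoE balancing loss
    outputs = (hidden_states,)
    if output_attentions:
        outputs += (attn_weights,)
    if use_cache:
        outputs += (present_key_value,)
    return outputs + (aux_loss,)

hidden = torch.zeros(1, 5, 8)
aux_total = 0
for _ in range(3):  # pretend 3 decoder layers
    layer_outputs = fake_decoder_layer(hidden, output_attentions=True)
    hidden = layer_outputs[0]
    attn = layer_outputs[1]
    cache = layer_outputs[2]            # 2 if output_attentions else 1
    aux_total += layer_outputs[-1]      # summed across layers, as in forward()
print(aux_total)                        # roughly tensor(0.0600)
```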
1100
-
1101
-
1102
- class JetMoEForCausalLM(JetMoEPreTrainedModel):
1103
- _tied_weights_keys = ["lm_head.weight"]
1104
-
1105
- def __init__(self, config):
1106
- super().__init__(config)
1107
- self.model = JetMoEModel(config)
1108
- self.vocab_size = config.vocab_size
1109
- self.aux_loss_coef = getattr(config, 'aux_loss_coef', 0.01)
1110
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1111
-
1112
- # Initialize weights and apply final processing
1113
- self.post_init()
1114
-
1115
- def get_input_embeddings(self):
1116
- return self.model.embed_tokens
1117
-
1118
- def set_input_embeddings(self, value):
1119
- self.model.embed_tokens = value
1120
-
1121
- def get_output_embeddings(self):
1122
- return self.lm_head
1123
-
1124
- def set_output_embeddings(self, new_embeddings):
1125
- self.lm_head = new_embeddings
1126
-
1127
- def set_decoder(self, decoder):
1128
- self.model = decoder
1129
-
1130
- def get_decoder(self):
1131
- return self.model
1132
-
1133
- @add_start_docstrings_to_model_forward(MODULEFORMER_INPUTS_DOCSTRING)
1134
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1135
- def forward(
1136
- self,
1137
- input_ids: torch.LongTensor = None,
1138
- attention_mask: Optional[torch.Tensor] = None,
1139
- position_ids: Optional[torch.LongTensor] = None,
1140
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1141
- inputs_embeds: Optional[torch.FloatTensor] = None,
1142
- labels: Optional[torch.LongTensor] = None,
1143
- use_cache: Optional[bool] = None,
1144
- output_attentions: Optional[bool] = None,
1145
- output_hidden_states: Optional[bool] = None,
1146
- return_dict: Optional[bool] = None,
1147
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1148
- r"""
1149
- Args:
1150
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1151
- Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1152
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1153
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1154
-
1155
- Returns:
1156
- """
1157
-
1158
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1159
- output_hidden_states = (
1160
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1161
- )
1162
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1163
-
1164
- # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1165
- outputs = self.model(
1166
- input_ids=input_ids,
1167
- attention_mask=attention_mask,
1168
- position_ids=position_ids,
1169
- past_key_values=past_key_values,
1170
- inputs_embeds=inputs_embeds,
1171
- use_cache=use_cache,
1172
- output_attentions=output_attentions,
1173
- output_hidden_states=output_hidden_states,
1174
- return_dict=return_dict,
1175
- )
1176
-
1177
- hidden_states = outputs[0]
1178
- logits = self.lm_head(hidden_states)
1179
- logits = logits.float()
1180
-
1181
- loss = None
1182
- if labels is not None:
1183
- # Shift so that tokens < n predict n
1184
- shift_logits = logits[..., :-1, :].contiguous()
1185
- shift_labels = labels[..., 1:].contiguous()
1186
- # Flatten the tokens
1187
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1188
- shift_labels = shift_labels.view(-1)
1189
- # Ensure tensors are on the same device
1190
- shift_labels = shift_labels.to(shift_logits.device)
1191
- loss_fct = CrossEntropyLoss()
1192
- loss = loss_fct(shift_logits, shift_labels)
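Editor's note: the one-position shift above is the standard causal LM objective; the logits at position *t* are scored against the token at position *t + 1*. A minimal numeric sketch (toy vocabulary of 5, sequence of 4 tokens) makes the slicing concrete; the numbers are arbitrary.

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len = 5, 4
logits = torch.randn(1, seq_len, vocab_size)    # (batch, seq, vocab)
labels = torch.tensor([[2, 4, 1, 3]])           # (batch, seq)

# Position t predicts token t+1, so drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous() # (1, 3, 5)
shift_labels = labels[..., 1:].contiguous()     # (1, 3)

loss = CrossEntropyLoss()(
    shift_logits.view(-1, vocab_size),          # (3, 5)
    shift_labels.view(-1),                      # (3,)
)
print(shift_labels)  # tensor([[4, 1, 3]])
print(loss.shape)    # torch.Size([]) -- a scalar
```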
1193
-
1194
- if not return_dict:
1195
- output = (logits,) + outputs[1:]
1196
- return (loss,) + output if loss is not None else output
1197
-
1198
- if labels is not None and self.model.training:
1199
- loss += self.aux_loss_coef * outputs.aux_loss.to(loss.device)
1200
-
1201
- return JetMoECausalLMOutputWithPast(
1202
- loss=loss,
1203
- logits=logits,
1204
- past_key_values=outputs.past_key_values,
1205
- hidden_states=outputs.hidden_states,
1206
- attentions=outputs.attentions,
1207
- aux_loss=outputs.aux_loss,
1208
- )
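Editor's note: because this file ships inside a model repository, the class is normally reached through `AutoModelForCausalLM` with `trust_remote_code=True` rather than imported directly. The snippet below is a hedged sketch of that flow; the checkpoint path is a placeholder, not a confirmed hub id, and generation arguments are kept to the basics.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "path/to/jetmoe-checkpoint"  # placeholder -- substitute a real checkpoint

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,   # required so the hub-side modeling code above is used
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Note that on the `return_dict` path the forward pass above also adds `self.aux_loss_coef * aux_loss` (default 0.01) to the language-modeling loss when labels are provided and the model is in training mode; at inference the auxiliary loss is only reported in the output.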
1209
-
1210
- def prepare_inputs_for_generation(
1211
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1212
- ):
1213
- # Omit tokens covered by past_key_values
1214
- if past_key_values is not None:
1215
- if isinstance(past_key_values, Cache):
1216
- cache_length = past_key_values.get_seq_length()
1217
- past_length = past_key_values.seen_tokens
1218
- max_cache_length = past_key_values.get_max_length()
1219
- else:
1220
- cache_length = past_length = past_key_values[0][0].shape[2]
1221
- max_cache_length = None
1222
-
1223
- # Keep only the unprocessed tokens:
1224
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1225
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1226
- # input)
1227
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1228
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1229
- # 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens. We can discard
1230
- # input_ids based on the past_length.
1231
- elif past_length < input_ids.shape[1]:
1232
- input_ids = input_ids[:, past_length:]
1233
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1234
-
1235
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1236
- if (
1237
- max_cache_length is not None
1238
- and attention_mask is not None
1239
- and cache_length + input_ids.shape[1] > max_cache_length
1240
- ):
1241
- attention_mask = attention_mask[:, -max_cache_length:]
1242
-
1243
- position_ids = kwargs.get("position_ids", None)
1244
- if attention_mask is not None and position_ids is None:
1245
- # create position_ids on the fly for batch generation
1246
- position_ids = attention_mask.long().cumsum(-1) - 1
1247
- position_ids.masked_fill_(attention_mask == 0, 1)
1248
- if past_key_values:
1249
- position_ids = position_ids[:, -input_ids.shape[1] :]
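Editor's note: the cumulative-sum trick above derives position ids from the attention mask so that left-padded rows still start counting at 0 from their first real token. A toy batch makes the effect visible; the values follow directly from the two tensor operations in the code.

```python
import torch

# Two sequences, left-padded to length 5 (0 = padding, 1 = real token).
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```

The padded positions receive a dummy value of 1, which is harmless because those positions are masked out of attention anyway.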
1250
-
1251
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1252
- if inputs_embeds is not None and past_key_values is None:
1253
- model_inputs = {"inputs_embeds": inputs_embeds}
1254
- else:
1255
- model_inputs = {"input_ids": input_ids}
1256
-
1257
- model_inputs.update(
1258
- {
1259
- "position_ids": position_ids,
1260
- "past_key_values": past_key_values,
1261
- "use_cache": kwargs.get("use_cache"),
1262
- "attention_mask": attention_mask,
1263
- }
1264
- )
1265
- return model_inputs
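Editor's note: the three branches that trim `input_ids` above are easiest to see with concrete lengths. The sketch below reproduces only the second branch (the common incremental-decoding case) with plain tensors; `past_length` stands in for the value computed from the cache, and the numbers are illustrative only.

```python
import torch

# Full prompt plus one freshly sampled token, as generate() passes it back in.
input_ids = torch.tensor([[11, 12, 13, 14, 15]])
past_length = 4  # positions already processed and stored in the KV cache

# Branch 2 above: the cache covers the first `past_length` tokens, so only
# the unprocessed suffix needs to go through the model on this step.
if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]

print(input_ids)  # tensor([[15]])
```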
1266
-
1267
- @staticmethod
1268
- def _reorder_cache(past_key_values, beam_idx):
1269
- reordered_past = ()
1270
- for layer_past in past_key_values:
1271
- reordered_past += (
1272
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1273
- )
1274
- return reordered_past
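Editor's note: `_reorder_cache` only has to permute the batch (beam) dimension of every cached key/value tensor so that each surviving beam keeps the history of the hypothesis it was forked from. A two-beam toy example with made-up cache shapes:

```python
import torch

# One layer's cached keys: (num_beams, num_heads, seq_len, head_dim).
keys = torch.stack([torch.full((2, 3, 4), 0.0), torch.full((2, 3, 4), 1.0)])
beam_idx = torch.tensor([1, 1])  # both surviving beams continue hypothesis 1

reordered = keys.index_select(0, beam_idx)
print(reordered[:, 0, 0, 0])  # tensor([1., 1.]) -- both beams now share history 1
```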
1275
-
1276
-
1277
- @add_start_docstrings(
1278
- """
1279
- The JetMoE Model transformer with a sequence classification head on top (linear layer).
1280
-
1281
- [`JetMoEForSequenceClassification`] uses the last token to perform classification, as other causal models
1282
- (e.g. GPT-2) do.
1283
-
1284
- Since it classifies on the last token, it needs to know the position of the last token. If a
1285
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1286
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1287
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
1288
- each row of the batch).
1289
- """,
1290
- MODULEFORMER_START_DOCSTRING,
1291
- )
1292
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->JetMoE, LLAMA->MODULEFORMER
1293
- class JetMoEForSequenceClassification(JetMoEPreTrainedModel):
1294
- def __init__(self, config):
1295
- super().__init__(config)
1296
- self.num_labels = config.num_labels
1297
- self.model = JetMoEModel(config)
1298
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1299
-
1300
- # Initialize weights and apply final processing
1301
- self.post_init()
1302
-
1303
- def get_input_embeddings(self):
1304
- return self.model.embed_tokens
1305
-
1306
- def set_input_embeddings(self, value):
1307
- self.model.embed_tokens = value
1308
-
1309
- @add_start_docstrings_to_model_forward(MODULEFORMER_INPUTS_DOCSTRING)
1310
- def forward(
1311
- self,
1312
- input_ids: torch.LongTensor = None,
1313
- attention_mask: Optional[torch.Tensor] = None,
1314
- position_ids: Optional[torch.LongTensor] = None,
1315
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1316
- inputs_embeds: Optional[torch.FloatTensor] = None,
1317
- labels: Optional[torch.LongTensor] = None,
1318
- use_cache: Optional[bool] = None,
1319
- output_attentions: Optional[bool] = None,
1320
- output_hidden_states: Optional[bool] = None,
1321
- return_dict: Optional[bool] = None,
1322
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1323
- r"""
1324
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1325
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1326
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1327
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1328
- """
1329
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1330
-
1331
- transformer_outputs = self.model(
1332
- input_ids,
1333
- attention_mask=attention_mask,
1334
- position_ids=position_ids,
1335
- past_key_values=past_key_values,
1336
- inputs_embeds=inputs_embeds,
1337
- use_cache=use_cache,
1338
- output_attentions=output_attentions,
1339
- output_hidden_states=output_hidden_states,
1340
- return_dict=return_dict,
1341
- )
1342
- hidden_states = transformer_outputs[0]
1343
- logits = self.score(hidden_states)
1344
-
1345
- if input_ids is not None:
1346
- batch_size = input_ids.shape[0]
1347
- else:
1348
- batch_size = inputs_embeds.shape[0]
1349
-
1350
- if self.config.pad_token_id is None and batch_size != 1:
1351
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1352
- if self.config.pad_token_id is None:
1353
- sequence_lengths = -1
1354
- else:
1355
- if input_ids is not None:
1356
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1357
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1358
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1359
- sequence_lengths = sequence_lengths.to(logits.device)
1360
- else:
1361
- sequence_lengths = -1
1362
-
1363
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
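Editor's note: the `argmax`-based computation above finds, for each row, the index of the last non-padding token without reverse indexing (kept ONNX-friendly, per the comment). A toy batch shows how it behaves with and without padding; the pad id of 0 is arbitrary.

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],   # 2 pad tokens -> last real token at index 2
    [5, 6, 7, 8, 9],   # no padding   -> last real token at index 4
])
batch_size, seq_len = input_ids.shape

# argmax returns the first pad position (or 0 if there is no pad);
# subtracting 1 and taking the modulo handles the no-padding row.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % seq_len
print(sequence_lengths)  # tensor([2, 4])

logits = torch.randn(batch_size, seq_len, 3)  # 3 labels, random scores
pooled = logits[torch.arange(batch_size), sequence_lengths]
print(pooled.shape)  # torch.Size([2, 3])
```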
1364
-
1365
- loss = None
1366
- if labels is not None:
1367
- labels = labels.to(logits.device)
1368
- if self.config.problem_type is None:
1369
- if self.num_labels == 1:
1370
- self.config.problem_type = "regression"
1371
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1372
- self.config.problem_type = "single_label_classification"
1373
- else:
1374
- self.config.problem_type = "multi_label_classification"
1375
-
1376
- if self.config.problem_type == "regression":
1377
- loss_fct = MSELoss()
1378
- if self.num_labels == 1:
1379
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1380
- else:
1381
- loss = loss_fct(pooled_logits, labels)
1382
- elif self.config.problem_type == "single_label_classification":
1383
- loss_fct = CrossEntropyLoss()
1384
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1385
- elif self.config.problem_type == "multi_label_classification":
1386
- loss_fct = BCEWithLogitsLoss()
1387
- loss = loss_fct(pooled_logits, labels)
1388
- if not return_dict:
1389
- output = (pooled_logits,) + transformer_outputs[1:]
1390
- return ((loss,) + output) if loss is not None else output
1391
-
1392
- return JetMoESequenceClassifierOutputWithPast(
1393
- loss=loss,
1394
- logits=pooled_logits,
1395
- past_key_values=transformer_outputs.past_key_values,
1396
- hidden_states=transformer_outputs.hidden_states,
1397
- attentions=transformer_outputs.attentions,
1398
- aux_loss=transformer_outputs.aux_loss,
1399
- )
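Editor's note: for reference, the `problem_type` dispatch above maps onto three standard losses. The sketch below exercises each branch on tiny tensors, mirroring the shapes the classification head produces; it is illustrative only and does not touch the model.

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

batch_size, num_labels = 2, 3
pooled_logits = torch.randn(batch_size, num_labels)

# "regression" with num_labels == 1: MSE between squeezed logits and float targets.
reg_logits = torch.randn(batch_size, 1)
reg_labels = torch.tensor([0.5, -1.0])
mse = MSELoss()(reg_logits.squeeze(), reg_labels.squeeze())

# "single_label_classification": one integer class per example, cross-entropy.
cls_labels = torch.tensor([2, 0])
ce = CrossEntropyLoss()(pooled_logits.view(-1, num_labels), cls_labels.view(-1))

# "multi_label_classification": independent sigmoid per label, float 0/1 targets.
multi_labels = torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
bce = BCEWithLogitsLoss()(pooled_logits, multi_labels)

print(mse.item(), ce.item(), bce.item())
```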