hiyouga committed on
Commit
1f41c4b
1 Parent(s): 43fb20e

Take input attention masks to support left-padded sequences


The previous implementation did not accept attention masks as inputs, which caused unexpected behaviour during batched inference, where left padding is commonly used. I therefore reimplemented the ALiBi encodings to build the attention bias from the attention mask supplied in the user inputs. Note that this implementation largely follows [1].

[1] https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py
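
For reference, batched, left-padded inference with this change looks like the minimal sketch below. The checkpoint path, dtype, pad-token fallback, and generation settings are placeholders for illustration, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/baichuan-checkpoint"  # placeholder path, not part of this commit

# left padding keeps the most recent tokens adjacent to the generated ones
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
if tokenizer.pad_token is None:  # some checkpoints ship without a pad token
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto"
)

prompts = ["What is ALiBi?", "Explain left padding in one sentence."]
# padding=True left-pads the shorter prompt and returns the matching attention_mask
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# generate() now forwards attention_mask through prepare_inputs_for_generation,
# so padded positions are masked out and the ALiBi biases follow the real token positions
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
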

Files changed (1)
  1. modeling_baichuan.py +248 -125
modeling_baichuan.py CHANGED
@@ -5,81 +5,127 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
+import torch.nn.functional as F
+from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
-from transformers.generation.utils import GenerationConfig
 
 from .configuration_baichuan import BaichuanConfig
 
+
 logger = logging.get_logger(__name__)
 
-def _get_interleave(n):
-    def _get_interleave_power_of_2(n):
-        start = (2 ** (-2 ** -(math.log2(n) - 3)))
-        ratio = start
-        return [start * ratio ** i for i in range(n)]
-
-    if math.log2(n).is_integer():
-        return _get_interleave_power_of_2(n)
-    else:
-        closest_power_of_2 = 2 ** math.floor(math.log2(n))
-        return _get_interleave_power_of_2(closest_power_of_2) + \
-            _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
-
-def _fill_with_neg_inf(t):
-    """FP16-compatible function that fills a tensor with -inf."""
-    return t.float().fill_(float("-inf")).type_as(t)
-
-def _gen_alibi_mask(n_head, max_pos):
-    slopes = torch.Tensor(_get_interleave(n_head))
-    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
-        n_head, -1, -1)
-    alibi = alibi.view(n_head, 1, max_pos)
-    alibi_mask = torch.triu(
-        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
+
+# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
+) -> torch.BoolTensor:
+    """
+    Make causal mask used for self-attention.
+    """
+    batch_size, target_length = input_ids_shape
+    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+    seq_ids = torch.arange(target_length, device=device)
+    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+
+    if past_key_values_length > 0:
+        mask[:, :past_key_values_length] = False
+
+    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
+    return expanded_mask
+
+
+# Copied from transformers.models.bloom.modeling_bloom._expand_mask
+def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
+    """
+    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
+    """
+    batch_size, src_length = mask.shape
+    tgt_length = tgt_length if tgt_length is not None else src_length
+
+    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
+    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
+
+
+# Copied from transformers.models.bloom.modeling_bloom.build_alibi_tensor
+def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+    """
+    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
+    `softmax(l+a) = softmax(l)`.
+
+    Args:
+    Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
+        attention_mask (`torch.Tensor`):
+            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+        num_heads (`int`, *required*):
+            number of heads
+        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
+            dtype of the output tensor
+    """
+    batch_size, seq_length = attention_mask.shape
+    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+    base = torch.tensor(
+        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
     )
-    alibi_mask = alibi_mask.unsqueeze(0) + alibi
-    return alibi_mask
+    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != num_heads:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+        )
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
+    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+    # => the query_length dimension will then be broadcasted correctly
+    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+    alibi = slopes[..., None] * arange_tensor
+    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
+
 
 
-class RMSNorm(torch.nn.Module):
+class RMSNorm(nn.Module):
+
     def __init__(self, hidden_size, epsilon=1e-6):
         super().__init__()
-        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
+        self.weight = nn.Parameter(torch.ones(hidden_size))
         self.epsilon = epsilon
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
 
-        # convert into half-precision
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+        return (self.weight * hidden_states).to(input_dtype)
 
 
-class MLP(torch.nn.Module):
+class MLP(nn.Module):
+
     def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
     ):
         super().__init__()
-        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
-        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
         self.act_fn = ACT2FN[hidden_act]
 
     def forward(self, x):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
-class BaichuanAttention(torch.nn.Module):
+class BaichuanAttention(nn.Module):
 
     def __init__(self, config: BaichuanConfig):
         super().__init__()
@@ -93,62 +139,89 @@ class BaichuanAttention(torch.nn.Module):
             raise ValueError(
                 f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
             )
-        self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
-        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+        # Layer-wise attention scaling
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.beta = 1.0
+
+        self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
         bsz, q_len, _ = hidden_states.size()
 
-        proj = self.W_pack(hidden_states)
+        proj = self.W_pack(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
         proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
-        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim)
+        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim)
 
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+        query_states = query_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
+        key_states = key_states.permute(0, 2, 3, 1).reshape(bsz * self.num_heads, self.head_dim, q_len)
+        value_states = value_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            past_key, past_value = past_key_value
+            key_states = torch.cat([past_key, key_states], dim=2)
+            value_states = torch.cat([past_value, value_states], dim=1)
+
+        _, _, kv_seq_len = key_states.shape
 
         past_key_value = (key_states, value_states) if use_cache else None
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # [batch_size * num_heads, q_length, kv_length]
+        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+        matmul_result = alibi.baddbmm(
+            batch1=query_states,
+            batch2=key_states,
+            beta=self.beta,
+            alpha=self.inv_norm_factor,
+        )
 
-        if attention_mask is not None:
-            if attn_weights.size(-2) == 1:
-                attention_mask = attention_mask[:, -1:, :]
-            attn_weights = attn_weights + attention_mask.unsqueeze(0)
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        # change view to [batch_size, num_heads, q_length, kv_length]
+        attention_scores = matmul_result.view(bsz, self.num_heads, q_len, kv_seq_len)
 
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-        attn_output = torch.matmul(attn_weights, value_states)
+        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
+        # [batch_size, num_heads, q_length, kv_length]
+        input_dtype = attention_scores.dtype
+        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+        if input_dtype == torch.float16:
+            attention_scores = attention_scores.to(torch.float)
+        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
+        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
 
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        # change view [batch_size x num_heads, q_length, kv_length]
+        attention_probs_reshaped = attention_probs.view(bsz * self.num_heads, q_len, kv_seq_len)
+
+        # matmul: [batch_size * num_heads, q_length, head_dim]
+        attn_output = torch.bmm(attention_probs_reshaped, value_states)
+
+        attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
+
+        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
-            attn_weights = None
+            attention_probs = None
 
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attention_probs, past_key_value
 
 
-class BaichuanLayer(torch.nn.Module):
+class BaichuanLayer(nn.Module):
+
     def __init__(self, config: BaichuanConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -162,12 +235,13 @@ class BaichuanLayer(torch.nn.Module):
         self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
+        self,
+        hidden_states: torch.Tensor,
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
         residual = hidden_states
@@ -177,6 +251,7 @@ class BaichuanLayer(torch.nn.Module):
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
+            alibi=alibi,
             attention_mask=attention_mask,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
@@ -192,6 +267,9 @@ class BaichuanLayer(torch.nn.Module):
 
         outputs = (hidden_states,)
 
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
         if use_cache:
             outputs += (present_key_value,)
 
@@ -203,15 +281,16 @@ class BaichuanPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BaichuanLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
 
     def _init_weights(self, module):
         std = self.config.initializer_range
-        if isinstance(module, torch.nn.Linear):
+        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
-        elif isinstance(module, torch.nn.Embedding):
+        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
@@ -221,49 +300,69 @@ class BaichuanPreTrainedModel(PreTrainedModel):
             module.gradient_checkpointing = value
 
 
-
 class BaichuanModel(BaichuanPreTrainedModel):
+
     def __init__(self, config: BaichuanConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.n_head = config.num_attention_heads
-        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
 
         self.gradient_checkpointing = config.gradient_checkpointing
         self.post_init()
-        self.max_cache_pos = config.model_max_length
-        self.first_run = True
 
     def get_input_embeddings(self):
         return self.embed_tokens
 
     def set_input_embeddings(self, value):
         self.embed_tokens = value
-
-    def get_alibi_mask(self, tensor, seq_length_with_past):
-        if self.first_run:
-            self.first_run = False
-            self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
-        if seq_length_with_past > self.max_cache_pos:
-            self.max_cache_pos = seq_length_with_past
-            self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
-        mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
-        return mask
+
+    def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+        return build_alibi_tensor(attention_mask, num_heads, dtype)
+
+    def _prepare_attn_mask(
+        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+    ) -> torch.BoolTensor:
+        # create causal mask
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        combined_attention_mask = None
+        device = attention_mask.device
+        _, src_length = input_shape
+
+        if src_length > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape, device=device, past_key_values_length=past_key_values_length
+            )
+
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
+        combined_attention_mask = (
+            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+        )
+
+        return combined_attention_mask
 
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        return_dict: Optional[bool] = True,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
-
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
@@ -275,19 +374,21 @@ class BaichuanModel(BaichuanPreTrainedModel):
             raise ValueError("You need to provide input_ids or inputs_embeds")
 
         seq_length_with_past = seq_length
-
+        past_key_values_length = 0
         if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
+            past_key_values_length = past_key_values[0][0].shape[1]
             seq_length_with_past = seq_length_with_past + past_key_values_length
 
         if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
 
-        # embed positions
-        attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
-
         hidden_states = inputs_embeds
 
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        else:
+            attention_mask = attention_mask.to(hidden_states.device)
+
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -295,6 +396,15 @@ class BaichuanModel(BaichuanPreTrainedModel):
                 )
             use_cache = False
 
+        # Compute alibi tensor: check build_alibi_tensor documentation
+        alibi = self.build_alibi_tensor(attention_mask, self.n_head, dtype=hidden_states.dtype)
+
+        causal_mask = self._prepare_attn_mask(
+            attention_mask,
+            input_shape=(batch_size, seq_length),
+            past_key_values_length=past_key_values_length,
+        )
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -318,13 +428,15 @@ class BaichuanModel(BaichuanPreTrainedModel):
                 layer_outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(decoder_layer),
                     hidden_states,
-                    attention_mask,
+                    alibi,
+                    causal_mask,
                     None,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=attention_mask,
+                    alibi=alibi,
+                    attention_mask=causal_mask,
                     past_key_value=past_key_value,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
@@ -345,8 +457,10 @@ class BaichuanModel(BaichuanPreTrainedModel):
             all_hidden_states += (hidden_states,)
 
         next_cache = next_decoder_cache if use_cache else None
+
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
@@ -356,10 +470,12 @@ class BaichuanModel(BaichuanPreTrainedModel):
 
 
 class BaichuanForCausalLM(BaichuanPreTrainedModel):
+
     def __init__(self, config):
         super().__init__(config)
         self.model = BaichuanModel(config)
-        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -383,22 +499,28 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         return self.model
 
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        return_dict: Optional[bool] = True,
-        **kwargs
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
@@ -436,7 +558,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
         if past_key_values:
             input_ids = input_ids[:, -1:]
@@ -451,6 +573,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
             }
         )
         return model_inputs