Commit d1816c6 by GradientGuru
1 Parent(s): 43fb20e

fix alibi problem and support attention mask

Files changed (1)
  1. modeling_baichuan.py  +72 -26
modeling_baichuan.py CHANGED
@@ -16,6 +16,7 @@ from .configuration_baichuan import BaichuanConfig
 
 logger = logging.get_logger(__name__)
 
+
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = (2 ** (-2 ** -(math.log2(n) - 3)))
@@ -44,6 +45,16 @@ def _gen_alibi_mask(n_head, max_pos):
     alibi_mask = alibi_mask.unsqueeze(0) + alibi
     return alibi_mask
 
+def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
+    """for training only"""
+    dim = tensor.size(1)
+    _future_mask = torch.triu(
+        _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
+    )
+    _future_mask = _future_mask.unsqueeze(0) + alibi
+    _future_mask = _future_mask.to(tensor)
+    return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos]
+
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, epsilon=1e-6):
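
Note: `_buffered_future_mask` pairs the per-head ALiBi bias with a causal (upper-triangular -inf) mask at the current sequence length; it is only used on the training path. A minimal self-contained sketch of the same construction, assuming a power-of-two head count so the slope formula stays simple (`alibi_causal_mask` is an illustrative name, not a helper from this file):

import math
import torch

def alibi_causal_mask(n_head: int, seq_len: int) -> torch.Tensor:
    # geometric slope sequence, valid when n_head is a power of two
    start = 2 ** (-2 ** -(math.log2(n_head) - 3))
    slopes = torch.tensor([start * start ** i for i in range(n_head)])
    # per-head linear bias over key positions: [n_head, 1, seq_len]
    alibi = slopes[:, None, None] * torch.arange(seq_len)[None, None, :]
    # causal part: future positions get -inf so softmax zeroes them out
    future = torch.triu(torch.full((seq_len, seq_len), float("-inf")), 1)
    return future[None, :, :] + alibi          # [n_head, seq_len, seq_len]

print(alibi_causal_mask(n_head=8, seq_len=5).shape)   # torch.Size([8, 5, 5])
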
@@ -80,7 +91,6 @@ class MLP(torch.nn.Module):
 
 
 class BaichuanAttention(torch.nn.Module):
-
     def __init__(self, config: BaichuanConfig):
         super().__init__()
         self.config = config
@@ -130,12 +140,16 @@ class BaichuanAttention(torch.nn.Module):
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attention_mask is not None:
-            if attn_weights.size(-2) == 1:
-                attention_mask = attention_mask[:, -1:, :]
-            attn_weights = attn_weights + attention_mask.unsqueeze(0)
+            if q_len == 1:  # inference with cache
+                if len(attention_mask.size()) == 4:
+                    attention_mask = attention_mask[:, :, -1:, :]
+                else:
+                    attention_mask = attention_mask[:, -1:, :]
+            attn_weights = attn_weights + attention_mask
             attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
         attn_output = torch.matmul(attn_weights, value_states)
 
         attn_output = attn_output.transpose(1, 2)
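
Note: with a KV cache the incremental decode step has q_len == 1, so only the last query row of the additive mask applies; the hunk above now also accepts a 4D, batch-expanded mask. A small sketch of the slicing rule (function name is illustrative):

import torch

def slice_mask_for_decode_step(mask: torch.Tensor) -> torch.Tensor:
    # 4D: [bsz, n_head, src_len, tgt_len]; 3D: [n_head, src_len, tgt_len]
    if mask.dim() == 4:
        return mask[:, :, -1:, :]
    return mask[:, -1:, :]

print(slice_mask_for_decode_step(torch.zeros(8, 5, 5)).shape)   # torch.Size([8, 1, 5])
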
@@ -239,23 +253,32 @@ class BaichuanModel(BaichuanPreTrainedModel):
 
     def get_input_embeddings(self):
         return self.embed_tokens
-
+
     def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
+        self.embed_tokens = value
+
     def get_alibi_mask(self, tensor, seq_length_with_past):
-        if self.first_run:
-            self.first_run = False
-            self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
-        if seq_length_with_past > self.max_cache_pos:
-            self.max_cache_pos = seq_length_with_past
-            self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
-        mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
+        if self.training:
+            slopes = torch.Tensor(_get_interleave(self.n_head))
+            alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand(
+                self.n_head,
+                -1, -1)
+            alibi = alibi.view(self.n_head, 1, seq_length_with_past)
+            mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head)
+        else:
+            if self.first_run:
+                self.first_run = False
+                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
+            if seq_length_with_past > self.max_cache_pos:
+                self.max_cache_pos = seq_length_with_past
+                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
+            mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
         return mask
 
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
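
Note: `get_alibi_mask` now has two paths: during training the mask is rebuilt from the current sequence length via `_buffered_future_mask`, while at inference it is cached in the non-persistent `future_mask` buffer (so it follows the module's device but stays out of the state dict), sliced per request, and only rebuilt when a longer sequence arrives. A simplified sketch of that caching pattern (`MaskCache` and `build_fn` are illustrative names):

import torch

class MaskCache:
    def __init__(self, n_head: int, max_pos: int, build_fn):
        self.n_head = n_head
        self.max_pos = max_pos
        self.build_fn = build_fn                  # e.g. a mask builder like the sketch above
        self.mask = build_fn(n_head, max_pos)     # built once up front

    def get(self, seq_len: int) -> torch.Tensor:
        if seq_len > self.max_pos:                # grow the cache only for longer sequences
            self.max_pos = seq_len
            self.mask = self.build_fn(self.n_head, self.max_pos)
        return self.mask[:self.n_head, :seq_len, :seq_len]
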
@@ -283,8 +306,23 @@ class BaichuanModel(BaichuanPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        # embed positions
-        attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
+
+        alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
+        if attention_mask is not None:
+            if len(attention_mask.shape) == 2:
+                expanded_mask = attention_mask.to(alibi_mask.dtype)
+                expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
+                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
+            else:
+                expanded_mask = attention_mask
+            bsz = inputs_embeds.size(0)
+            src_len, tgt_len = alibi_mask.size()[-2:]
+            expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype)
+            inverted_mask = 1.0 - expanded_mask
+            inverted_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min)
+            attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
+        else:
+            attention_mask = alibi_mask
 
         hidden_states = inputs_embeds
 
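Note: this hunk turns a 2D tokenizer-style padding mask (1 = real token, 0 = padding) into an additive bias and adds it to the ALiBi mask; an already-expanded mask is passed through unchanged. A condensed sketch of the 2D path (`expand_padding_mask` is an illustrative name):

import torch

def expand_padding_mask(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    m = padding_mask.to(dtype)                                     # [bsz, seq_len]
    # causal pairs where both query and key are real tokens
    pair = torch.tril(torch.gt(m[:, :, None] * m[:, None, :], 0)) \
        * torch.eq(m[:, :, None] - m[:, None, :], 0)
    inverted = 1.0 - pair.unsqueeze(1).to(dtype)                   # [bsz, 1, seq_len, seq_len]
    # disallowed positions get a very large negative bias
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

pad = torch.tensor([[1, 1, 1, 0]])        # one sequence with a trailing pad token
print(expand_padding_mask(pad).shape)     # torch.Size([1, 1, 4, 4])
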
 
@@ -353,7 +391,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
+
 
 class BaichuanForCausalLM(BaichuanPreTrainedModel):
     def __init__(self, config):
@@ -381,10 +419,11 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
     def get_decoder(self):
         return self.model
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -399,13 +438,14 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-        )
+        )
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
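
Note: `BaichuanForCausalLM.forward` now simply forwards `attention_mask` to the decoder. A hypothetical end-to-end call with a padded batch; the model id and the pad-token setup are assumptions for illustration, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "baichuan-inc/Baichuan-13B-Chat"   # example id; substitute the repository this commit belongs to
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
if tokenizer.pad_token is None:           # padding requires a pad token
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

batch = tokenizer(["Hello", "A longer second prompt"], return_tensors="pt", padding=True)
out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(out.logits.shape)                   # [batch_size, seq_len, vocab_size]
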
@@ -436,8 +476,13 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
         if past_key_values:
             input_ids = input_ids[:, -1:]
 
@@ -448,11 +493,12 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         model_inputs = {"input_ids": input_ids}
 
         model_inputs.update(
-            {
+            {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
-            }
-        )
+                "attention_mask": attention_mask
+            }
+        )
         return model_inputs
 
     @staticmethod
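
Note: with `attention_mask` now carried through `prepare_inputs_for_generation`, padded batches can be generated directly. Continuing the sketch above, and assuming left padding (the usual choice for decoder-only batched generation):

tokenizer.padding_side = "left"
batch = tokenizer(["Hello", "A longer second prompt"], return_tensors="pt", padding=True)
generated = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    max_new_tokens=32,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
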
@@ -470,7 +516,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             raise ImportError(
                 f"Needs QLinear to run quantize."
             )
-
+
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
@@ -497,7 +543,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
                 weight=layer.mlp.up_proj.weight,
                 bias = None,
             )
-        return self
+        return self
 
     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
 