hiyouga committed on
Commit
478dc90
1 Parent(s): 3a663f5

update cache format to support contrastive search and beam search

Files changed (1)
  1. modeling_baichuan.py +83 -18
modeling_baichuan.py CHANGED
@@ -300,6 +300,45 @@ class BaichuanPreTrainedModel(PreTrainedModel):
         if isinstance(module, BaichuanModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_baichuan_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 class BaichuanModel(BaichuanPreTrainedModel):
 
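The two helpers above are inverse view operations on each layer's (key, value) pair: Baichuan keeps the cache fused as [batch_size * num_heads, ...], while the rest of the generation machinery expects [batch_size, num_heads, ...]. A minimal round-trip sketch with dummy tensors (the dimensions below are arbitrary, chosen only for illustration):

import torch

# Baichuan-format cache for a single layer:
#   key   -> [batch_size * num_heads, head_dim, seq_length]
#   value -> [batch_size * num_heads, seq_length, head_dim]
batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5
key = torch.randn(batch_size * num_heads, head_dim, seq_length)
value = torch.randn(batch_size * num_heads, seq_length, head_dim)
past = ((key, value),)  # one layer

# what _convert_to_standard_cache does: split the fused batch*heads dimension
standard = tuple(
    (k.view(batch_size, num_heads, head_dim, seq_length),
     v.view(batch_size, num_heads, seq_length, head_dim))
    for k, v in past
)

# what _convert_to_baichuan_cache does: fuse it back
roundtrip = tuple(
    (k.view(batch_size * num_heads, head_dim, seq_length),
     v.view(batch_size * num_heads, seq_length, head_dim))
    for k, v in standard
)

assert torch.equal(roundtrip[0][0], key) and torch.equal(roundtrip[0][1], value)
print(standard[0][0].shape)  # torch.Size([2, 4, 8, 5])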
 
@@ -318,9 +357,9 @@ class BaichuanModel(BaichuanPreTrainedModel):
 
     def get_input_embeddings(self):
         return self.embed_tokens
-
+
     def set_input_embeddings(self, value):
-        self.embed_tokens = value
+        self.embed_tokens = value
 
     def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
         return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
+
 
 class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
@@ -498,7 +537,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
     def get_decoder(self):
         return self.model
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-        )
+        )
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
@@ -559,10 +598,19 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
         if past_key_values:
             input_ids = input_ids[:, -1:]
+
+            # the cache may be in the standard format (e.g. in contrastive search)
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_baichuan_cache(past_key_values)
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
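Note on the new shape check: a Baichuan-format key has leading dimension batch_size * num_heads, while the standard layout's leading dimension is just batch_size, so (for num_heads > 1) comparing past_key_values[0][0].shape[0] against input_ids.shape[0] is enough to tell that the incoming cache is already in the standard format and must be converted back before the forward pass.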
@@ -571,21 +619,38 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         model_inputs = {"input_ids": input_ids}
 
         model_inputs.update(
-            {
+            {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
-            }
-        )
+            }
+        )
         return model_inputs
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
+        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
         )
-
+        return self._convert_to_baichuan_cache(reordered_past)
 
     def quantize(self, bits: int):
         try:
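The rewritten _reorder_cache round-trips through the standard layout so that index_select operates on a true batch dimension (beam_idx indexes sequences, not batch_size * num_heads rows), and the device_to_beam_idx map materializes beam_idx once per device, which matters when layers are sharded across GPUs.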
@@ -594,7 +659,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             raise ImportError(
                 f"Needs QLinear to run quantize."
             )
-
+
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
@@ -621,7 +686,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
                 weight=layer.mlp.up_proj.weight,
                 bias = None,
             )
-        return self
+        return self
 
     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
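For reference, a minimal sketch of the two decoding modes this commit targets. The checkpoint name, prompt, and generation settings are only illustrative; the model must be loaded with trust_remote_code=True so that this modeling_baichuan.py is used:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "baichuan-inc/Baichuan-13B-Chat"  # example checkpoint that ships modeling_baichuan.py
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)

inputs = tokenizer("Write a short poem about the sea.", return_tensors="pt").to(model.device)

# beam search: generate() calls _reorder_cache with a beam_idx at every step
beam_output = model.generate(**inputs, num_beams=4, max_new_tokens=64)

# contrastive search: the cache comes back in the standard [batch, num_heads, ...] layout,
# which prepare_inputs_for_generation now converts to Baichuan's fused layout
cs_output = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64)

print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
print(tokenizer.decode(cs_output[0], skip_special_tokens=True))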