update cache format to support contrastive search and beam search
modeling_baichuan.py CHANGED (+83 -18)
@@ -300,6 +300,45 @@ class BaichuanPreTrainedModel(PreTrainedModel):
         if isinstance(module, BaichuanModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_baichuan_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 class BaichuanModel(BaichuanPreTrainedModel):
 
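Note (not part of the patch): the fused Baichuan cache stores keys as [batch_size * num_heads, head_dim, seq_length] and values as [batch_size * num_heads, seq_length, head_dim], while the generic generation utilities expect a leading [batch_size, num_heads, ...] layout. A minimal round-trip sketch of the two converters above, with illustrative shapes, assuming this modeling file is importable locally as `modeling_baichuan`:

import torch

from modeling_baichuan import BaichuanPreTrainedModel  # assumption: the file is on the local path

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# One layer of cache in the fused Baichuan layout.
baichuan_cache = (
    (
        torch.randn(batch_size * num_heads, head_dim, seq_length),  # key
        torch.randn(batch_size * num_heads, seq_length, head_dim),  # value
    ),
)

# Fused -> standard: tuple(tuple([batch_size, num_heads, ...])).
standard = BaichuanPreTrainedModel._convert_to_standard_cache(baichuan_cache, batch_size=batch_size)
assert standard[0][0].shape == (batch_size, num_heads, head_dim, seq_length)
assert standard[0][1].shape == (batch_size, num_heads, seq_length, head_dim)

# Standard -> fused: the inverse view, recovering the original tensors exactly.
roundtrip = BaichuanPreTrainedModel._convert_to_baichuan_cache(standard)
assert torch.equal(roundtrip[0][0], baichuan_cache[0][0])
assert torch.equal(roundtrip[0][1], baichuan_cache[0][1])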
@@ -318,9 +357,9 @@ class BaichuanModel(BaichuanPreTrainedModel):
 
     def get_input_embeddings(self):
         return self.embed_tokens
+
     def set_input_embeddings(self, value):
+        self.embed_tokens = value
 
     def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
         return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
+
 
 class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
@@ -498,7 +537,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
     def get_decoder(self):
         return self.model
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+        )
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
@@ -559,10 +598,19 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
         if past_key_values:
             input_ids = input_ids[:, -1:]
+
+            # the cache may be in the standard format (e.g. in contrastive search)
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_baichuan_cache(past_key_values)
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
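Note (not part of the patch): the shape test above distinguishes the two layouts because a standard-format cache has batch_size as its leading dimension while the fused Baichuan cache has batch_size * num_heads. A small sketch of the check with illustrative shapes (it discriminates whenever num_heads > 1):

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5
input_ids = torch.ones(batch_size, 1, dtype=torch.long)  # one token per sequence during decoding

# Key tensor of the first layer in the standard layout (e.g. handed back by contrastive search).
standard_key = torch.randn(batch_size, num_heads, head_dim, seq_length)
# The same key in the fused Baichuan layout produced by the attention layers in this file.
baichuan_key = torch.randn(batch_size * num_heads, head_dim, seq_length)

# The test used in prepare_inputs_for_generation: a leading dimension equal to the
# batch size means the cache is in the standard format and needs to be converted back.
assert standard_key.shape[0] == input_ids.shape[0]
assert baichuan_key.shape[0] != input_ids.shape[0]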
@@ -571,21 +619,38 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             model_inputs = {"input_ids": input_ids}
 
         model_inputs.update(
+            {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
+            }
+        )
         return model_inputs
 
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
+        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
         )
+        return self._convert_to_baichuan_cache(reordered_past)
 
     def quantize(self, bits: int):
         try:
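Note (not part of the patch): during beam search, `generate` calls `_reorder_cache` with a `beam_idx` telling which previous beam each new beam continues from, so the gather has to happen on the per-sample batch dimension. That is why the cache is standardized first, gathered with `index_select`, and then converted back. A minimal sketch of the gather step on one standard-format key tensor (shapes and indices are illustrative):

import torch

num_beams, num_heads, head_dim, seq_length = 4, 2, 8, 5

# Standard-format key tensor: [num_beams, num_heads, head_dim, seq_length],
# filled so that row i is constant i, to make the reordering visible.
key = torch.arange(num_beams, dtype=torch.float32).view(num_beams, 1, 1, 1).expand(
    num_beams, num_heads, head_dim, seq_length
)

# New beams 0 and 1 both continue from old beam 2; beams 2 and 3 come from 0 and 3.
beam_idx = torch.tensor([2, 2, 0, 3])

reordered = key.index_select(0, beam_idx)
assert torch.equal(reordered[0], key[2])
assert torch.equal(reordered[1], key[2])
assert torch.equal(reordered[2], key[0])
assert torch.equal(reordered[3], key[3])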
@@ -594,7 +659,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             raise ImportError(
                 f"Needs QLinear to run quantize."
             )
+
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
@@ -621,7 +686,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
                 weight=layer.mlp.up_proj.weight,
                 bias = None,
             )
+        return self
 
     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
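With the cache converters, the widened `prepare_inputs_for_generation`, and `_reorder_cache` in place, the cache-shuffling decoding modes of `generate` can run against this model. A hedged usage sketch: the checkpoint name, prompt, and generation settings are illustrative, `trust_remote_code=True` is assumed so that this modeling file is actually used, and loading in fp16 with device_map="auto" assumes a suitable GPU and the accelerate package.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "baichuan-inc/Baichuan-13B-Chat"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)

inputs = tokenizer("The three laws of thermodynamics are", return_tensors="pt").to(model.device)

# Beam search: exercises _reorder_cache at every decoding step.
beam_output = model.generate(**inputs, num_beams=4, max_new_tokens=64)

# Contrastive search: hands the cache back in the standard format, which
# prepare_inputs_for_generation now converts to the Baichuan layout.
cs_output = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64)

print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
print(tokenizer.decode(cs_output[0], skip_special_tokens=True))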