zxdu20 committed
Commit a2eaddf · 1 parent: 4ebf106

Sync with chatglm-6b

configuration_chatglm.py CHANGED
@@ -72,6 +72,8 @@ class ChatGLMConfig(PretrainedConfig):
         position_encoding_2d=True,
         quantization_bit=0,
         quantization_embeddings=False,
+        pre_seq_len=None,
+        prefix_projection=False,
         **kwargs
     ):
         self.num_layers = num_layers
@@ -86,8 +88,11 @@ class ChatGLMConfig(PretrainedConfig):
         self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
         self.position_encoding_2d = position_encoding_2d
-        self.quantization_bit=quantization_bit
-        self.quantization_embeddings=quantization_embeddings
+        self.quantization_bit = quantization_bit
+        self.quantization_embeddings = quantization_embeddings
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
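
For context, a minimal sketch of how the two new configuration fields might be used to enable prefix tuning (P-Tuning v2); the repository id and the value of pre_seq_len are illustrative assumptions, not part of this commit:

    # Illustrative only: enabling the new prefix-tuning fields on a loaded config.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    config.pre_seq_len = 128          # length of the trainable prefix; None leaves prefix tuning disabled
    config.prefix_projection = False  # False: embedding-only PrefixEncoder; True: two-layer MLP projection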
modeling_chatglm.py CHANGED
@@ -5,6 +5,7 @@ import copy
5
  import os
6
  import warnings
7
  import re
 
8
 
9
  import torch
10
  import torch.utils.checkpoint
@@ -12,7 +13,7 @@ import torch.nn.functional as F
12
  from torch import nn
13
  from torch.nn import CrossEntropyLoss, LayerNorm
14
  from torch.nn.utils import skip_init
15
- from typing import Optional, Tuple, Union, List, Callable
16
 
17
  from transformers.utils import (
18
  add_code_sample_docstrings,
@@ -27,16 +28,18 @@ from transformers.modeling_outputs import (
27
  from transformers.modeling_utils import PreTrainedModel
28
  from transformers.utils import logging
29
  from transformers.generation.logits_process import LogitsProcessor
30
- from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
31
 
32
  from .configuration_chatglm import ChatGLMConfig
33
 
34
 
35
  # flags required to enable jit fusion kernels
36
- torch._C._jit_set_profiling_mode(False)
37
- torch._C._jit_set_profiling_executor(False)
38
- torch._C._jit_override_can_fuse_on_cpu(True)
39
- torch._C._jit_override_can_fuse_on_gpu(True)
40
 
41
  logger = logging.get_logger(__name__)
42
 
@@ -131,6 +134,36 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
131
  return model
132
 
133
 
134
  @torch.jit.script
135
  def gelu_impl(x):
136
  """OpenAI's gelu implementation."""
@@ -219,7 +252,7 @@ def attention_fn(
219
  use_cache=False,
220
  ):
221
  if layer_past is not None:
222
- past_key, past_value = layer_past
223
  key_layer = torch.cat((past_key, key_layer), dim=0)
224
  value_layer = torch.cat((past_value, value_layer), dim=0)
225
 
@@ -273,7 +306,7 @@ def attention_fn(
273
  if not (attention_mask == 0).all():
274
  # if auto-regressive, skip
275
  attention_scores.masked_fill_(attention_mask, -10000.0)
276
- dtype = attention_scores.type()
277
  attention_scores = attention_scores.float()
278
  attention_scores = attention_scores * query_key_layer_scaling_coeff
279
 
@@ -619,10 +652,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
619
  """
620
 
621
  is_parallelizable = False
622
- supports_gradient_checkpointing = False
623
  config_class = ChatGLMConfig
624
  base_model_prefix = "transformer"
625
- _no_split_modules = ["GLM6BBlock"]
626
 
627
  def __init__(self, *inputs, **kwargs):
628
  super().__init__(*inputs, **kwargs)
@@ -631,6 +664,43 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
631
  """Initialize the weights."""
632
  return
633
 
634
 
635
  CHATGLM_6B_START_DOCSTRING = r"""
636
  This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
@@ -727,12 +797,15 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
727
  self.inner_hidden_size = config.inner_hidden_size
728
  self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
729
  self.position_encoding_2d = config.position_encoding_2d
730
 
731
  self.word_embeddings = skip_init(
732
  torch.nn.Embedding,
733
  num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
734
  dtype=self.params_dtype
735
  )
 
736
 
737
  def get_layer(layer_id):
738
  return GLMBlock(
@@ -755,43 +828,38 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
755
  # Final layer norm before output.
756
  self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
757
 
758
  def get_input_embeddings(self):
759
  return self.word_embeddings
760
 
761
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
762
  self.word_embeddings = new_embeddings
763
 
764
- def get_masks(self, seq, device):
765
- context_length = seq.index(self.config.bos_token_id) + 1
766
-
767
- attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
768
- attention_mask.tril_()
769
- attention_mask[..., :context_length - 1] = 1
770
- attention_mask.unsqueeze_(1)
771
- attention_mask = (attention_mask < 0.5).bool()
772
-
773
- return attention_mask
774
-
775
- def get_position_ids(self, seq, mask_position, device, gmask=False):
776
- context_length = seq.index(self.config.bos_token_id) + 1
777
- if self.position_encoding_2d:
778
- seq_length = seq.index(self.config.bos_token_id)
779
- position_ids = torch.arange(context_length, dtype=torch.long, device=device)
780
- if not gmask:
781
- position_ids[seq_length:] = mask_position
782
- block_position_ids = torch.cat((
783
- torch.zeros(seq_length, dtype=torch.long, device=device),
784
- torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
785
- ))
786
- position_ids = torch.stack((position_ids, block_position_ids), dim=0)
787
- else:
788
- position_ids = torch.arange(context_length, dtype=torch.long, device=device)
789
- if not gmask:
790
- position_ids[context_length - 1:] = mask_position
791
-
792
- position_ids = position_ids.unsqueeze(0)
793
-
794
- return position_ids
795
 
796
  @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
797
  @add_code_sample_docstrings(
@@ -819,6 +887,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
819
  use_cache = use_cache if use_cache is not None else self.config.use_cache
820
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
821
 
822
  if input_ids is not None and inputs_embeds is not None:
823
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
824
  elif input_ids is not None:
@@ -828,31 +903,41 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
828
  else:
829
  raise ValueError("You have to specify either input_ids or inputs_embeds")
830
831
  if past_key_values is None:
832
- past_key_values = tuple([None] * len(self.layers))
833
- seq = input_ids[0].tolist()
834
 
835
  if attention_mask is None:
836
  attention_mask = self.get_masks(
837
- seq=seq,
838
  device=input_ids.device
839
  )
840
841
  if position_ids is None:
842
  MASK, gMASK = 150000, 150001
843
  mask_token = MASK if MASK in input_ids else gMASK
844
  use_gmask = False if MASK in input_ids else gMASK
845
 
846
- mask_position = seq.index(mask_token)
847
  position_ids = self.get_position_ids(
848
- seq=seq,
849
- mask_position=mask_position,
850
  device=input_ids.device,
851
  gmask=use_gmask
852
  )
853
 
854
- if inputs_embeds is None:
855
- inputs_embeds = self.word_embeddings(input_ids)
856
 
857
  # [seq_len, batch, hidden_size]
858
  hidden_states = inputs_embeds.transpose(0, 1)
@@ -861,11 +946,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
861
  all_self_attentions = () if output_attentions else None
862
  all_hidden_states = () if output_hidden_states else None
863
 
864
- seq_length_with_past = seq_length
865
- past_key_values_length = 0
866
- if past_key_values[0] is not None:
867
- past_key_values_length = past_key_values[0][0].shape[0]
868
- seq_length_with_past = seq_length_with_past + past_key_values_length
869
  if attention_mask is None:
870
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
871
 
@@ -876,16 +956,29 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
876
 
877
  if output_hidden_states:
878
  all_hidden_states = all_hidden_states + (hidden_states,)
879
-
880
- layer_ret = layer(
881
- hidden_states,
882
- position_ids=position_ids,
883
- attention_mask=attention_mask,
884
- layer_id=torch.tensor(i),
885
- layer_past=past_key_values[i],
886
- use_cache=use_cache,
887
- output_attentions=output_attentions
888
- )
889
 
890
  hidden_states = layer_ret[0]
891
 
@@ -946,31 +1039,40 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
946
  def set_output_embeddings(self, new_embeddings):
947
  self.lm_head = new_embeddings
948
 
949
- def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
950
- attention_mask = torch.ones((1, context_length, context_length), device=device)
951
- attention_mask.tril_()
952
- attention_mask[..., :context_length - 1] = 1
953
- attention_mask.unsqueeze_(1)
954
- attention_mask = (attention_mask < 0.5).bool()
955
 
956
- if self.position_encoding_2d:
957
- seq_length = seq.index(self.config.bos_token_id)
958
- position_ids = torch.arange(context_length, dtype=torch.long, device=device)
959
- if not gmask:
960
- position_ids[seq_length:] = mask_position
961
- block_position_ids = torch.cat((
962
- torch.zeros(seq_length, dtype=torch.long, device=device),
963
- torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
964
- ))
965
- position_ids = torch.stack((position_ids, block_position_ids), dim=0)
966
- else:
967
- position_ids = torch.arange(context_length, dtype=torch.long, device=device)
968
- if not gmask:
969
- position_ids[context_length - 1:] = mask_position
970
 
971
- position_ids = position_ids.unsqueeze(0)
972
 
973
- return attention_mask, position_ids
974
 
975
  def prepare_inputs_for_generation(
976
  self,
@@ -978,27 +1080,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
978
  past: Optional[torch.Tensor] = None,
979
  past_key_values: Optional[torch.Tensor] = None,
980
  attention_mask: Optional[torch.Tensor] = None,
 
981
  **kwargs
982
  ) -> dict:
983
-
984
  MASK, gMASK = 150000, 150001
985
  mask_token = MASK if MASK in input_ids else gMASK
986
  use_gmask = False if MASK in input_ids else gMASK
987
- seq = input_ids[0].tolist()
988
- mask_position = seq.index(mask_token)
989
-
990
- if mask_token not in seq:
991
- raise ValueError("You have to add either [MASK] or [gMASK] in your input")
992
 
993
  # only last token for input_ids if past is not None
994
  if past is not None or past_key_values is not None:
995
- context_length = seq.index(self.config.bos_token_id)
996
  last_token = input_ids[:, -1].unsqueeze(-1)
997
- if self.position_encoding_2d:
998
- position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
999
- device=input_ids.device)
1000
  else:
1001
- position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
1002
 
1003
  if past is None:
1004
  past = past_key_values
@@ -1006,15 +1115,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1006
  "input_ids": last_token,
1007
  "past_key_values": past,
1008
  "position_ids": position_ids,
 
1009
  }
1010
  else:
1011
- attention_mask, position_ids = self.get_masks_and_position_ids(
1012
- seq=seq,
1013
- mask_position=mask_position,
1014
- context_length=len(seq),
1015
- device=input_ids.device,
1016
- gmask=use_gmask
1017
- )
1018
 
1019
  return {
1020
  "input_ids": input_ids,
@@ -1063,7 +1181,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1063
  shift_logits = lm_logits[..., :-1, :].contiguous()
1064
  shift_labels = labels[..., 1:].contiguous()
1065
  # Flatten the tokens
1066
- loss_fct = CrossEntropyLoss()
1067
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1068
 
1069
  lm_logits = lm_logits.to(hidden_states.dtype)
@@ -1132,10 +1250,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1132
  for i, (old_query, response) in enumerate(history):
1133
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1134
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1135
- input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1136
- input_ids = input_ids.to(self.device)
1137
- outputs = self.generate(**input_ids, **gen_kwargs)
1138
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1139
  response = tokenizer.decode(outputs)
1140
  response = self.process_response(response)
1141
  history = history + [(query, response)]
@@ -1158,10 +1276,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1158
  for i, (old_query, response) in enumerate(history):
1159
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1160
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1161
- input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1162
- input_ids = input_ids.to(self.device)
1163
- for outputs in self.stream_generate(**input_ids, **gen_kwargs):
1164
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1165
  response = tokenizer.decode(outputs)
1166
  response = self.process_response(response)
1167
  new_history = history + [(query, response)]
 
5
  import os
6
  import warnings
7
  import re
8
+ import sys
9
 
10
  import torch
11
  import torch.utils.checkpoint
 
13
  from torch import nn
14
  from torch.nn import CrossEntropyLoss, LayerNorm
15
  from torch.nn.utils import skip_init
16
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
17
 
18
  from transformers.utils import (
19
  add_code_sample_docstrings,
 
28
  from transformers.modeling_utils import PreTrainedModel
29
  from transformers.utils import logging
30
  from transformers.generation.logits_process import LogitsProcessor
31
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
32
 
33
  from .configuration_chatglm import ChatGLMConfig
34
 
35
 
36
  # flags required to enable jit fusion kernels
37
+
38
+ if sys.platform != 'darwin':
39
+ torch._C._jit_set_profiling_mode(False)
40
+ torch._C._jit_set_profiling_executor(False)
41
+ torch._C._jit_override_can_fuse_on_cpu(True)
42
+ torch._C._jit_override_can_fuse_on_gpu(True)
43
 
44
  logger = logging.get_logger(__name__)
45
 
 
134
  return model
135
 
136
 
137
+ class PrefixEncoder(torch.nn.Module):
138
+ """
139
+ The torch.nn model to encode the prefix
140
+ Input shape: (batch-size, prefix-length)
141
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
142
+ """
143
+
144
+ def __init__(self, config):
145
+ super().__init__()
146
+ self.prefix_projection = config.prefix_projection
147
+ if self.prefix_projection:
148
+ # Use a two-layer MLP to encode the prefix
149
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
150
+ self.trans = torch.nn.Sequential(
151
+ torch.nn.Linear(config.hidden_size, config.hidden_size),
152
+ torch.nn.Tanh(),
153
+ torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
154
+ )
155
+ else:
156
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
157
+
158
+ def forward(self, prefix: torch.Tensor):
159
+ if self.prefix_projection:
160
+ prefix_tokens = self.embedding(prefix)
161
+ past_key_values = self.trans(prefix_tokens)
162
+ else:
163
+ past_key_values = self.embedding(prefix)
164
+ return past_key_values
165
+
166
+
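
A quick shape check for the new PrefixEncoder; the toy config values below are assumptions chosen for readability (ChatGLM-6B itself uses 28 layers and hidden size 4096):

    # Assumes the PrefixEncoder class defined above is in scope.
    import torch

    class ToyConfig:
        prefix_projection = False   # embedding-only branch
        pre_seq_len = 4
        num_layers = 2
        hidden_size = 8

    encoder = PrefixEncoder(ToyConfig())
    prefix = torch.arange(ToyConfig.pre_seq_len).unsqueeze(0)   # (batch=1, prefix_len=4)
    print(encoder(prefix).shape)                                # torch.Size([1, 4, 32]) = (b, prefix, 2*layers*hidden)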
167
  @torch.jit.script
168
  def gelu_impl(x):
169
  """OpenAI's gelu implementation."""
 
252
  use_cache=False,
253
  ):
254
  if layer_past is not None:
255
+ past_key, past_value = layer_past[0], layer_past[1]
256
  key_layer = torch.cat((past_key, key_layer), dim=0)
257
  value_layer = torch.cat((past_value, value_layer), dim=0)
258
 
 
306
  if not (attention_mask == 0).all():
307
  # if auto-regressive, skip
308
  attention_scores.masked_fill_(attention_mask, -10000.0)
309
+ dtype = attention_scores.dtype
310
  attention_scores = attention_scores.float()
311
  attention_scores = attention_scores * query_key_layer_scaling_coeff
312
 
 
652
  """
653
 
654
  is_parallelizable = False
655
+ supports_gradient_checkpointing = True
656
  config_class = ChatGLMConfig
657
  base_model_prefix = "transformer"
658
+ _no_split_modules = ["GLMBlock"]
659
 
660
  def __init__(self, *inputs, **kwargs):
661
  super().__init__(*inputs, **kwargs)
 
664
  """Initialize the weights."""
665
  return
666
 
667
+ def get_masks(self, input_ids, device):
668
+ batch_size, seq_length = input_ids.shape
669
+ context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
670
+ attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
671
+ attention_mask.tril_()
672
+ for i, context_length in enumerate(context_lengths):
673
+ attention_mask[i, :, :context_length] = 1
674
+ attention_mask.unsqueeze_(1)
675
+ attention_mask = (attention_mask < 0.5).bool()
676
+
677
+ return attention_mask
678
+
679
+ def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
680
+ batch_size, seq_length = input_ids.shape
681
+ context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
682
+ if self.position_encoding_2d:
683
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
684
+ for i, context_length in enumerate(context_lengths):
685
+ position_ids[i, context_length:] = mask_positions[i]
686
+ block_position_ids = [torch.cat((
687
+ torch.zeros(context_length, dtype=torch.long, device=device),
688
+ torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
689
+ )) for context_length in context_lengths]
690
+ block_position_ids = torch.stack(block_position_ids, dim=0)
691
+ position_ids = torch.stack((position_ids, block_position_ids), dim=1)
692
+ else:
693
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
694
+ if not gmask:
695
+ for i, context_length in enumerate(context_lengths):
696
+ position_ids[context_length:] = mask_positions[i]
697
+
698
+ return position_ids
699
+
700
+ def _set_gradient_checkpointing(self, module, value=False):
701
+ if isinstance(module, ChatGLMModel):
702
+ module.gradient_checkpointing = value
703
+
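
A small worked example of the batched helpers above (the token layout is a made-up assumption): take one sequence [t1, t2, gMASK, bos, g1, g2], so context_length = 3 (index of bos) and mask_position = 2.

    # Following get_position_ids with position_encoding_2d=True:
    #   position_ids       = [0, 1, 2, 2, 2, 2]   # positions after the context are pinned to the mask position
    #   block_position_ids = [0, 0, 0, 1, 2, 3]   # zeros over the context, then 1, 2, ... for generated tokens
    # stacked per sample -> position_ids tensor of shape (batch_size, 2, seq_length)
    # get_masks builds a (batch_size, 1, seq_length, seq_length) bool mask that is causal after the
    # context and fully visible within it (True marks blocked positions).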
704
 
705
  CHATGLM_6B_START_DOCSTRING = r"""
706
  This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
 
797
  self.inner_hidden_size = config.inner_hidden_size
798
  self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
799
  self.position_encoding_2d = config.position_encoding_2d
800
+ self.pre_seq_len = config.pre_seq_len
801
+ self.prefix_projection = config.prefix_projection
802
 
803
  self.word_embeddings = skip_init(
804
  torch.nn.Embedding,
805
  num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
806
  dtype=self.params_dtype
807
  )
808
+ self.gradient_checkpointing = False
809
 
810
  def get_layer(layer_id):
811
  return GLMBlock(
 
828
  # Final layer norm before output.
829
  self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
830
 
831
+ if self.pre_seq_len is not None:
832
+ for param in self.parameters():
833
+ param.requires_grad = False
834
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
835
+ self.prefix_encoder = PrefixEncoder(config)
836
+ self.dropout = torch.nn.Dropout(0.1)
837
+
838
+ # total_params = sum(p.numel() for p in self.parameters())
839
+ # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
840
+ # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))
841
+
842
  def get_input_embeddings(self):
843
  return self.word_embeddings
844
 
845
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
846
  self.word_embeddings = new_embeddings
847
 
848
+ def get_prompt(self, batch_size, device, dtype=torch.half):
849
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
850
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
851
+ past_key_values = past_key_values.view(
852
+ batch_size,
853
+ self.pre_seq_len,
854
+ self.num_layers * 2,
855
+ self.num_attention_heads,
856
+ self.hidden_size // self.num_attention_heads
857
+ )
858
+ # seq_len, b, nh, hidden_size
859
+ past_key_values = self.dropout(past_key_values)
860
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
861
+ # past_key_values = [(v[0], v[1]) for v in past_key_values]
862
+ return past_key_values
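
Shape walk-through of get_prompt, using ChatGLM-6B's commonly cited sizes (28 layers, 32 heads, head dim 128) and an assumed pre_seq_len of 128; only the reshaping is reproduced here:

    import torch

    b, pre_seq_len, num_layers, heads, head_dim = 1, 128, 28, 32, 128
    pkv = torch.zeros(b, pre_seq_len, num_layers * 2 * heads * head_dim)     # PrefixEncoder output
    pkv = pkv.view(b, pre_seq_len, num_layers * 2, heads, head_dim)
    pkv = pkv.permute([2, 1, 0, 3, 4]).split(2)                              # split along the 2*layers dim
    print(len(pkv), pkv[0].shape)   # 28 torch.Size([2, 128, 1, 32, 128]): one (key, value) pair per layer,
                                    # already in attention_fn's [seq_len, batch, heads, head_dim] layout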
863
 
864
  @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
865
  @add_code_sample_docstrings(
 
887
  use_cache = use_cache if use_cache is not None else self.config.use_cache
888
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
889
 
890
+ if self.gradient_checkpointing and self.training:
891
+ if use_cache:
892
+ logger.warning_once(
893
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
894
+ )
895
+ use_cache = False
896
+
897
  if input_ids is not None and inputs_embeds is not None:
898
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
899
  elif input_ids is not None:
 
903
  else:
904
  raise ValueError("You have to specify either input_ids or inputs_embeds")
905
 
906
+ if inputs_embeds is None:
907
+ inputs_embeds = self.word_embeddings(input_ids)
908
+
909
  if past_key_values is None:
910
+ if self.pre_seq_len is not None:
911
+ past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
912
+ dtype=inputs_embeds.dtype)
913
+ else:
914
+ past_key_values = tuple([None] * len(self.layers))
915
 
916
  if attention_mask is None:
917
  attention_mask = self.get_masks(
918
+ input_ids,
919
  device=input_ids.device
920
  )
921
 
922
+
923
  if position_ids is None:
924
  MASK, gMASK = 150000, 150001
925
  mask_token = MASK if MASK in input_ids else gMASK
926
  use_gmask = False if MASK in input_ids else gMASK
927
 
928
+ mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
929
  position_ids = self.get_position_ids(
930
+ input_ids,
931
+ mask_positions=mask_positions,
932
  device=input_ids.device,
933
  gmask=use_gmask
934
  )
935
 
936
+ if self.pre_seq_len is not None and attention_mask is not None:
937
+ prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
938
+ attention_mask.device)
939
+ prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
940
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
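
Shape sketch for the prefix attention-mask concatenation above (batch size and lengths are illustrative assumptions; the context-visibility part of get_masks is omitted):

    import torch

    batch, seq_len, pre_seq_len = 2, 10, 128
    attention_mask = torch.ones(batch, 1, seq_len, seq_len).tril() < 0.5          # True = blocked
    prefix_attention_mask = torch.ones(batch, 1, seq_len, pre_seq_len) < 0.5      # all False: prefix always visible
    print(torch.cat((prefix_attention_mask, attention_mask), dim=3).shape)        # torch.Size([2, 1, 10, 138])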
941
 
942
  # [seq_len, batch, hidden_size]
943
  hidden_states = inputs_embeds.transpose(0, 1)
 
946
  all_self_attentions = () if output_attentions else None
947
  all_hidden_states = () if output_hidden_states else None
948
949
  if attention_mask is None:
950
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
951
 
 
956
 
957
  if output_hidden_states:
958
  all_hidden_states = all_hidden_states + (hidden_states,)
959
+ layer_past = past_key_values[i]
960
+
961
+ if self.gradient_checkpointing and self.training:
962
+ layer_ret = torch.utils.checkpoint.checkpoint(
963
+ layer,
964
+ hidden_states,
965
+ position_ids,
966
+ attention_mask,
967
+ torch.tensor(i),
968
+ layer_past,
969
+ use_cache,
970
+ output_attentions
971
+ )
972
+ else:
973
+ layer_ret = layer(
974
+ hidden_states,
975
+ position_ids=position_ids,
976
+ attention_mask=attention_mask,
977
+ layer_id=torch.tensor(i),
978
+ layer_past=layer_past,
979
+ use_cache=use_cache,
980
+ output_attentions=output_attentions
981
+ )
982
 
983
  hidden_states = layer_ret[0]
984
 
 
1039
  def set_output_embeddings(self, new_embeddings):
1040
  self.lm_head = new_embeddings
1041
 
1042
+ def _update_model_kwargs_for_generation(
1043
+ self,
1044
+ outputs: ModelOutput,
1045
+ model_kwargs: Dict[str, Any],
1046
+ is_encoder_decoder: bool = False,
1047
+ standardize_cache_format: bool = False,
1048
+ ) -> Dict[str, Any]:
1049
+ # update past_key_values
1050
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
1051
+ outputs, standardize_cache_format=standardize_cache_format
1052
+ )
1053
 
1054
+ # update attention mask
1055
+ if "attention_mask" in model_kwargs:
1056
+ attention_mask = model_kwargs["attention_mask"]
1057
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
1058
+ attention_mask = torch.cat(
1059
+ [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
1060
+ new_attention_mask = attention_mask[:, :, -1:].clone()
1061
+ new_attention_mask[..., -1] = False
1062
+ model_kwargs["attention_mask"] = torch.cat(
1063
+ [attention_mask, new_attention_mask], dim=2
1064
+ )
1065
 
1066
+ # update position ids
1067
+ if "position_ids" in model_kwargs:
1068
+ position_ids = model_kwargs["position_ids"]
1069
+ new_position_id = position_ids[..., -1:].clone()
1070
+ new_position_id[:, 1, :] += 1
1071
+ model_kwargs["position_ids"] = torch.cat(
1072
+ [position_ids, new_position_id], dim=-1
1073
+ )
1074
 
1075
+ return model_kwargs
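
In effect, each generation step grows the cached bool mask and the 2D position ids by one position (the shapes below are illustrative assumptions for batch 1 and a prompt of length 10):

    # attention_mask : (1, 1, 10, 10) -> (1, 1, 11, 11)
    #   old query rows get a blocked (True) column for the new token; the appended query row copies the
    #   last row and sets its own slot to False so the new token can attend to itself.
    # position_ids   : (1, 2, 10) -> (1, 2, 11)
    #   the position row repeats its last value, the block-position row is incremented by 1.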
1076
 
1077
  def prepare_inputs_for_generation(
1078
  self,
 
1080
  past: Optional[torch.Tensor] = None,
1081
  past_key_values: Optional[torch.Tensor] = None,
1082
  attention_mask: Optional[torch.Tensor] = None,
1083
+ position_ids: Optional[torch.Tensor] = None,
1084
  **kwargs
1085
  ) -> dict:
1086
+ batch_size, seq_length = input_ids.shape
1087
  MASK, gMASK = 150000, 150001
1088
  mask_token = MASK if MASK in input_ids else gMASK
1089
  use_gmask = False if MASK in input_ids else gMASK
1090
+ seqs = input_ids.tolist()
1091
+ mask_positions = [seq.index(mask_token) for seq in seqs]
1092
 
1093
  # only last token for input_ids if past is not None
1094
  if past is not None or past_key_values is not None:
 
1095
  last_token = input_ids[:, -1].unsqueeze(-1)
1096
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
1097
+ attention_mask = attention_mask[:, :, -1:]
 
1098
  else:
1099
+ attention_mask = None
1100
+ if position_ids is not None:
1101
+ position_ids = position_ids[..., -1:]
1102
+ else:
1103
+ context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
1104
+ if self.position_encoding_2d:
1105
+ position_ids = torch.tensor(
1106
+ [[mask_position, seq_length - context_length] for mask_position, context_length in
1107
+ zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
1108
+ else:
1109
+ position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
1110
+ device=input_ids.device).unsqueeze(-1)
1111
 
1112
  if past is None:
1113
  past = past_key_values
 
1115
  "input_ids": last_token,
1116
  "past_key_values": past,
1117
  "position_ids": position_ids,
1118
+ "attention_mask": attention_mask
1119
  }
1120
  else:
1121
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
1122
+ logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
1123
+ attention_mask = None
1124
+ if attention_mask is None:
1125
+ attention_mask = self.get_masks(
1126
+ input_ids,
1127
+ device=input_ids.device
1128
+ )
1129
+ if position_ids is None:
1130
+ position_ids = self.get_position_ids(
1131
+ input_ids,
1132
+ device=input_ids.device,
1133
+ mask_positions=mask_positions,
1134
+ gmask=use_gmask
1135
+ )
1136
 
1137
  return {
1138
  "input_ids": input_ids,
 
1181
  shift_logits = lm_logits[..., :-1, :].contiguous()
1182
  shift_labels = labels[..., 1:].contiguous()
1183
  # Flatten the tokens
1184
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1185
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1186
 
1187
  lm_logits = lm_logits.to(hidden_states.dtype)
 
1250
  for i, (old_query, response) in enumerate(history):
1251
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1252
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1253
+ inputs = tokenizer([prompt], return_tensors="pt")
1254
+ inputs = inputs.to(self.device)
1255
+ outputs = self.generate(**inputs, **gen_kwargs)
1256
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1257
  response = tokenizer.decode(outputs)
1258
  response = self.process_response(response)
1259
  history = history + [(query, response)]
 
1276
  for i, (old_query, response) in enumerate(history):
1277
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1278
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1279
+ inputs = tokenizer([prompt], return_tensors="pt")
1280
+ inputs = inputs.to(self.device)
1281
+ for outputs in self.stream_generate(**inputs, **gen_kwargs):
1282
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1283
  response = tokenizer.decode(outputs)
1284
  response = self.process_response(response)
1285
  new_history = history + [(query, response)]
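
For reference, a hedged usage sketch of the chat()/stream_chat() paths touched above; the repository id, dtype and device handling are assumptions, not part of this diff:

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

    response, history = model.chat(tokenizer, "你好", history=[])
    for response, history in model.stream_chat(tokenizer, "你好", history=[]):
        print(response)   # incrementally refined response text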
quantization.py CHANGED
@@ -7,10 +7,13 @@ import bz2
7
  import torch
8
  import base64
9
  import ctypes
 
10
 
11
  from typing import List
12
  from functools import partial
13
 
 
 
14
  try:
15
  from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
16
 
@@ -37,18 +40,18 @@ try:
37
  )
38
  except Exception as exception:
39
  kernels = None
40
- print("Failed to load cpm_kernels:", exception)
41
 
42
 
43
  class W8A16Linear(torch.autograd.Function):
44
  @staticmethod
45
  def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
46
  ctx.inp_shape = inp.size()
47
- ctx.weight_shape = quant_w.size()
48
  ctx.weight_bit_width = weight_bit_width
49
  out_features = quant_w.size(0)
50
  inp = inp.contiguous().view(-1, inp.size(-1))
51
  weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
 
52
  output = inp.mm(weight.t())
53
  ctx.save_for_backward(inp, quant_w, scale_w)
54
  return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@@ -60,18 +63,18 @@ class W8A16Linear(torch.autograd.Function):
60
  grad_output = grad_output.contiguous().view(-1, weight.size(0))
61
  grad_input = grad_output.mm(weight)
62
  grad_weight = grad_output.t().mm(inp)
63
- return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
64
 
65
 
66
  class W8A16LinearCPU(torch.autograd.Function):
67
  @staticmethod
68
  def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None):
69
  ctx.inp_shape = inp.size()
70
- ctx.weight_shape = quant_w.size()
71
  ctx.weight_bit_width = weight_bit_width
72
  out_features = quant_w.size(0)
73
  inp = inp.contiguous().view(-1, inp.size(-1))
74
  weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
 
75
  output = inp.mm(weight.t())
76
  ctx.save_for_backward(inp, quant_w, scale_w)
77
  return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@@ -83,7 +86,7 @@ class W8A16LinearCPU(torch.autograd.Function):
83
  grad_output = grad_output.contiguous().view(-1, weight.size(0))
84
  grad_input = grad_output.mm(weight)
85
  grad_weight = grad_output.t().mm(inp)
86
- return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
87
 
88
 
89
  default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")
@@ -168,7 +171,7 @@ class CPUKernel:
168
  print("Load kernel :", kernel_file)
169
  else:
170
  print("Failed to load kernel.")
171
-
172
  if compile_parallel_kernel:
173
  if parallel_num is None:
174
  parallel_num = max(os.cpu_count() // 2, 1)
@@ -176,7 +179,7 @@ class CPUKernel:
176
  if parallel_num < 4:
177
  print("Parallel kernel is not recommended when parallel num < 4.")
178
  self.SetNumThreads(parallel_num)
179
-
180
  self.parallel_num = parallel_num
181
 
182
 
@@ -284,10 +287,10 @@ def extract_weight_to_float(weight: torch.Tensor, scale_list: torch.Tensor, sour
284
  class CacheTensor():
285
  def __init__(self, *args, **kwargs):
286
  self.tensor = torch.empty(*args, **kwargs)
287
-
288
  def to(self, *args, **kwargs):
289
  self.tensor = self.tensor.to(*args, **kwargs)
290
-
291
  def data_ptr(self):
292
  return self.tensor.data_ptr()
293
 
@@ -393,7 +396,7 @@ def load_cpu_kernel(**kwargs):
393
 
394
  def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
395
  """Replace fp16 linear with quantized linear"""
396
-
397
  query_key_value_quantization_cache = None
398
  dense_quantization_cache = None
399
  dense_h_to_4h_quantization_cache = None
 
7
  import torch
8
  import base64
9
  import ctypes
10
+ from transformers.utils import logging
11
 
12
  from typing import List
13
  from functools import partial
14
 
15
+ logger = logging.get_logger(__name__)
16
+
17
  try:
18
  from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
19
 
 
40
  )
41
  except Exception as exception:
42
  kernels = None
43
+ logger.warning("Failed to load cpm_kernels:", exception)
44
 
45
 
46
  class W8A16Linear(torch.autograd.Function):
47
  @staticmethod
48
  def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
49
  ctx.inp_shape = inp.size()
 
50
  ctx.weight_bit_width = weight_bit_width
51
  out_features = quant_w.size(0)
52
  inp = inp.contiguous().view(-1, inp.size(-1))
53
  weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
54
+ ctx.weight_shape = weight.size()
55
  output = inp.mm(weight.t())
56
  ctx.save_for_backward(inp, quant_w, scale_w)
57
  return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
 
63
  grad_output = grad_output.contiguous().view(-1, weight.size(0))
64
  grad_input = grad_output.mm(weight)
65
  grad_weight = grad_output.t().mm(inp)
66
+ return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
67
 
68
 
69
  class W8A16LinearCPU(torch.autograd.Function):
70
  @staticmethod
71
  def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None):
72
  ctx.inp_shape = inp.size()
 
73
  ctx.weight_bit_width = weight_bit_width
74
  out_features = quant_w.size(0)
75
  inp = inp.contiguous().view(-1, inp.size(-1))
76
  weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
77
+ ctx.weight_shape = weight.size()
78
  output = inp.mm(weight.t())
79
  ctx.save_for_backward(inp, quant_w, scale_w)
80
  return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
 
86
  grad_output = grad_output.contiguous().view(-1, weight.size(0))
87
  grad_input = grad_output.mm(weight)
88
  grad_weight = grad_output.t().mm(inp)
89
+ return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
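
The backward fix follows the torch.autograd.Function rule that backward must return exactly one gradient per forward argument (here: inp, quant_w, scale_w, weight_bit_width), using None for non-differentiable inputs; a toy example of the same pattern (not this kernel):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor):        # two forward inputs ...
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):     # ... so backward returns two values
            return grad_output * ctx.factor, None   # None: no gradient for the plain-Python factor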
90
 
91
 
92
  default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")
 
171
  print("Load kernel :", kernel_file)
172
  else:
173
  print("Failed to load kernel.")
174
+
175
  if compile_parallel_kernel:
176
  if parallel_num is None:
177
  parallel_num = max(os.cpu_count() // 2, 1)
 
179
  if parallel_num < 4:
180
  print("Parallel kernel is not recommended when parallel num < 4.")
181
  self.SetNumThreads(parallel_num)
182
+
183
  self.parallel_num = parallel_num
184
 
185
 
 
287
  class CacheTensor():
288
  def __init__(self, *args, **kwargs):
289
  self.tensor = torch.empty(*args, **kwargs)
290
+
291
  def to(self, *args, **kwargs):
292
  self.tensor = self.tensor.to(*args, **kwargs)
293
+
294
  def data_ptr(self):
295
  return self.tensor.data_ptr()
296
 
 
396
 
397
  def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs):
398
  """Replace fp16 linear with quantized linear"""
399
+
400
  query_key_value_quantization_cache = None
401
  dense_quantization_cache = None
402
  dense_h_to_4h_quantization_cache = None
tokenization_chatglm.py CHANGED
@@ -1,17 +1,14 @@
1
  """Tokenization classes for ChatGLM."""
2
- import sys
3
- import unicodedata
4
  from typing import List, Optional, Union
5
- from functools import lru_cache
6
  import os
7
- import collections
8
- import re
9
 
10
  from transformers.tokenization_utils import PreTrainedTokenizer
11
  from icetk.text_tokenizer import TextTokenizer
12
- from icetk.utils import auto_create
13
  import icetk.sentencepiece_model_pb2 as sp_model
14
- from transformers.utils import logging
15
 
16
  logger = logging.get_logger(__name__)
17
 
@@ -180,7 +177,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
180
 
181
  vocab_files_names = {"vocab_file": "ice_text.model"}
182
  max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
183
- model_input_names = ["input_ids"]
184
 
185
  def __init__(
186
  self,
@@ -210,7 +207,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
210
  self.eos_token = eos_token
211
  self.eop_token = eop_token
212
  self.mask_token = mask_token
213
- self.gMASK_token = gmask_token
214
 
215
  self.sp_tokenizer = SPTokenizer(vocab_file)
216
 
@@ -331,10 +328,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
331
  Returns:
332
  `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
333
  """
334
- if token_ids_1 is not None:
335
- token_ids_0 += token_ids_1
336
  mask_ids = self.sp_tokenizer[self.mask_token]
337
- gmask_ids = self.sp_tokenizer[self.gMASK_token]
 
338
  if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
339
  token_ids_0 += [gmask_ids]
340
 
@@ -343,4 +339,101 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
343
 
344
  token_ids_0 += [self.sp_tokenizer[self.bos_token]]
345
 
346
  return token_ids_0
1
  """Tokenization classes for ChatGLM."""
 
 
2
  from typing import List, Optional, Union
 
3
  import os
 
 
4
 
5
  from transformers.tokenization_utils import PreTrainedTokenizer
6
  from icetk.text_tokenizer import TextTokenizer
 
7
  import icetk.sentencepiece_model_pb2 as sp_model
8
+ from transformers.utils import logging, PaddingStrategy
9
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
10
+ from typing import Dict
11
+ import numpy as np
12
 
13
  logger = logging.get_logger(__name__)
14
 
 
177
 
178
  vocab_files_names = {"vocab_file": "ice_text.model"}
179
  max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
180
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
181
 
182
  def __init__(
183
  self,
 
207
  self.eos_token = eos_token
208
  self.eop_token = eop_token
209
  self.mask_token = mask_token
210
+ self.gmask_token = gmask_token
211
 
212
  self.sp_tokenizer = SPTokenizer(vocab_file)
213
 
 
328
  Returns:
329
  `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
330
  """
 
 
331
  mask_ids = self.sp_tokenizer[self.mask_token]
332
+ gmask_ids = self.sp_tokenizer[self.gmask_token]
333
+ eop_id = self.sp_tokenizer[self.eop_token]
334
  if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
335
  token_ids_0 += [gmask_ids]
336
 
 
339
 
340
  token_ids_0 += [self.sp_tokenizer[self.bos_token]]
341
 
342
+ if token_ids_1 is not None:
343
+ if not token_ids_1 or token_ids_1[-1] != eop_id:
344
+ token_ids_1 += [eop_id]
345
+ token_ids_0 += token_ids_1
346
+
347
  return token_ids_0
348
+
349
+ def _pad(
350
+ self,
351
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
352
+ max_length: Optional[int] = None,
353
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
354
+ pad_to_multiple_of: Optional[int] = None,
355
+ return_attention_mask: Optional[bool] = None,
356
+ ) -> dict:
357
+ """
358
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
359
+
360
+ Args:
361
+ encoded_inputs:
362
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
363
+ max_length: maximum length of the returned list and optionally padding length (see below).
364
+ Will truncate by taking into account the special tokens.
365
+ padding_strategy: PaddingStrategy to use for padding.
366
+
367
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
368
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
369
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
370
+ The tokenizer padding sides are defined in self.padding_side:
371
+
372
+ - 'left': pads on the left of the sequences
373
+ - 'right': pads on the right of the sequences
374
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
375
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
376
+ `>= 7.5` (Volta).
377
+ return_attention_mask:
378
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
379
+ """
380
+ # Load from model defaults
381
+ bos_token_id = self.sp_tokenizer[self.bos_token]
382
+ mask_token_id = self.sp_tokenizer[self.mask_token]
383
+ gmask_token_id = self.sp_tokenizer[self.gmask_token]
384
+ assert self.padding_side == "left"
385
+
386
+ required_input = encoded_inputs[self.model_input_names[0]]
387
+ seq_length = len(required_input)
388
+
389
+ if padding_strategy == PaddingStrategy.LONGEST:
390
+ max_length = len(required_input)
391
+
392
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
393
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
394
+
395
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
396
+
397
+ # Initialize attention mask if not present.
398
+ if max_length is not None:
399
+ if "attention_mask" not in encoded_inputs:
400
+ if bos_token_id in required_input:
401
+ context_length = required_input.index(bos_token_id)
402
+ else:
403
+ context_length = seq_length
404
+ attention_mask = np.ones((1, seq_length, seq_length))
405
+ attention_mask = np.tril(attention_mask)
406
+ attention_mask[:, :, :context_length] = 1
407
+ attention_mask = np.bool_(attention_mask < 0.5)
408
+ encoded_inputs["attention_mask"] = attention_mask
409
+
410
+ if "position_ids" not in encoded_inputs:
411
+ position_ids = np.arange(seq_length, dtype=np.int64)
412
+ mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
413
+ if mask_token in required_input:
414
+ mask_position = required_input.index(mask_token)
415
+ position_ids[context_length:] = mask_position
416
+ block_position_ids = np.concatenate(
417
+ [np.zeros(context_length, dtype=np.int64),
418
+ np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
419
+ encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
420
+
421
+ if needs_to_be_padded:
422
+ difference = max_length - len(required_input)
423
+
424
+ if "attention_mask" in encoded_inputs:
425
+ encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
426
+ pad_width=[(0, 0), (difference, 0), (difference, 0)],
427
+ mode='constant', constant_values=True)
428
+ if "token_type_ids" in encoded_inputs:
429
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
430
+ "token_type_ids"
431
+ ]
432
+ if "special_tokens_mask" in encoded_inputs:
433
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
434
+ if "position_ids" in encoded_inputs:
435
+ encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
436
+ pad_width=[(0, 0), (difference, 0)])
437
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
438
+
439
+ return encoded_inputs
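
A rough sketch of what the new left-side _pad produces for one sample (token ids are made-up assumptions), padding a length-5 input to max_length=8:

    # input_ids      -> [pad, pad, pad, t1, t2, gMASK, bos, t3]
    # attention_mask -> np.bool_ array of shape (1, 8, 8); the three padded rows/columns are True (blocked)
    # position_ids   -> array of shape (2, 8); both the position row and the block-position row are
    #                   left-padded with zeros (np.pad's default constant value)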