moontidef committed
Commit
9c63ce0
1 Parent(s): 807ba34

feat: add support for SequenceClassification

Files changed (2)
  1. config.json +2 -1
  2. modeling_xlm_roberta.py +149 -34
config.json CHANGED
@@ -3,7 +3,8 @@
     "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
     "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
     "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
-    "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM"
+    "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM",
+    "AutoModelForSequenceClassification":"modeling_xlm_roberta.XLMRobertaForSequenceClassification"
   },
   "attention_probs_dropout_prob": 0.1,
   "bos_token_id": 0,
modeling_xlm_roberta.py CHANGED
@@ -18,10 +18,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from einops import rearrange
 from transformers import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from transformers.modeling_outputs import MaskedLMOutput
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
 from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
 
 from transformers.models.bert.modeling_bert import (
@@ -429,7 +430,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         for key, value in kwargs.items():
             if value is not None:
                 logger.warning(
-                    'Flash attention implementation does not support kwargs: %s',
+                    "Flash attention implementation does not support kwargs: %s",
                     key,
                 )
 
@@ -834,47 +835,47 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
         if not last_layer_subset or d != (config.num_hidden_layers - 1):
             Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
             Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.query.weight"
-            ] = Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.key.weight"
-            ] = Wqkv_weights[
-                Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
-            ]
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.value.weight"
-            ] = Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.query.bias"
-            ] = Wqkv_biases[: Wqkv_biases.shape[0] // 3]
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.key.bias"
-            ] = Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.value.bias"
-            ] = Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
+            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
+                Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
+                Wqkv_weights[
+                    Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
+                ]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
+                Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = (
+                Wqkv_biases[: Wqkv_biases.shape[0] // 3]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = (
+                Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
+                Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
+            )
         else:
             Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
             Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
             Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
             Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.query.weight"
-            ] = Wq_weight
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.key.weight"
-            ] = Wkv_weights[: Wkv_weights.shape[0] // 2, :]
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.value.weight"
-            ] = Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
+            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
+                Wq_weight
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
+                Wkv_weights[: Wkv_weights.shape[0] // 2, :]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
+                Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
+            )
             state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
             state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                 : Wkv_biases.shape[0] // 2
             ]
-            state_dict[
-                f"bert.encoder.layers.{d}.attention.self.value.bias"
-            ] = Wkv_biases[Wkv_biases.shape[0] // 2 :]
+            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
+                Wkv_biases[Wkv_biases.shape[0] // 2 :]
+            )
 
     def inv_key_mapping_ln(key):
         key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
@@ -946,3 +947,117 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
     )
 
     return state_dict
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
+class XLMRobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
+class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+        self.classifier = XLMRobertaClassificationHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
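
For context, a hedged sketch of what the new `labels` path does at inference/training time. The repository id and label values below are illustrative placeholders; the loss selection mirrors the upstream `RobertaForSequenceClassification` behaviour the class is copied from:

```python
# Sketch of the labels/loss behaviour added in this commit: integer (long) labels
# with num_labels > 1 select "single_label_classification" (CrossEntropyLoss);
# num_labels == 1 falls back to "regression" (MSELoss); float multi-hot labels
# select "multi_label_classification" (BCEWithLogitsLoss).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo_id = "org/xlm-roberta-flash-implementation"  # placeholder repository id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(
    repo_id, trust_remote_code=True, num_labels=3
)

batch = tokenizer(
    ["first example text", "second example text"],
    padding=True,
    return_tensors="pt",
)
labels = torch.tensor([0, 2])  # long dtype -> single_label_classification

out = model(**batch, labels=labels)
print(out.loss)          # cross-entropy over the 3 classes
print(out.logits.shape)  # torch.Size([2, 3])
```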