feat(modeling_stablelm_epoch.py): add support for AutoModelForSequenceClassification

#3
Files changed (2)
  1. config.json +4 -2
  2. modeling_stablelm_epoch.py +109 -2
config.json CHANGED
@@ -1,10 +1,12 @@
 {
   "architectures": [
-    "StableLMEpochForCausalLM"
+    "StableLMEpochForCausalLM",
+    "StableLMEpochForSequenceClassification"
   ],
   "auto_map": {
     "AutoConfig": "configuration_stablelm_epoch.StableLMEpochConfig",
-    "AutoModelForCausalLM": "modeling_stablelm_epoch.StableLMEpochForCausalLM"
+    "AutoModelForCausalLM": "modeling_stablelm_epoch.StableLMEpochForCausalLM",
+    "AutoModelForSequenceClassification": "modeling_stablelm_epoch.StableLMEpochForSequenceClassification"
   },
   "bos_token_id": 100257,
   "eos_token_id": 100257,
modeling_stablelm_epoch.py CHANGED
@@ -17,7 +17,7 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
 """ PyTorch StableLM Epoch model. """
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List
 import math
 import warnings
 
@@ -25,12 +25,13 @@ import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss
 
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging, is_flash_attn_greater_or_equal_2_10
@@ -913,5 +914,111 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
         return reordered_past
 
 
+class StableLMEpochForSequenceClassification(StableLMEpochPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = StableLMEpochModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
 StableLMEpochConfig.register_for_auto_class()
 StableLMEpochForCausalLM.register_for_auto_class("AutoModelForCausalLM")
+StableLMEpochForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
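
The new head follows the pattern used by the sequence-classification heads in transformers (e.g. `LlamaForSequenceClassification`): it reuses the `StableLMEpochModel` backbone, projects hidden states through a bias-free `score` linear layer, pools the logits at the last non-padding token of each sequence (hence the `pad_token_id` requirement for batch sizes above one), and picks MSE, cross-entropy, or BCE-with-logits loss from `config.problem_type`. A hedged end-to-end usage sketch, with a placeholder repo id and the pad token mapped to EOS as an assumption:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

repo_id = "your-org/stablelm-epoch-checkpoint"  # placeholder, not from this PR
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    repo_id, num_labels=2, trust_remote_code=True
)

# forward() pools the logit of the last non-padding token, so a pad token must
# be defined for batches larger than one; mapping pad to EOS is an assumption.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

inputs = tokenizer(["great model", "terrible model"], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # torch.Size([2, 2]): one pooled row per sequence
```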