bwang0911 commited on
Commit
64c81c6
1 Parent(s): 8542ad8

Update modeling_xlm_roberta.py (#12)

Browse files

- Update modeling_xlm_roberta.py (6a473f18306789a58285f781c0b6ec6f4df03fdb)

Files changed (1) hide show
  1. modeling_xlm_roberta.py +4 -3
modeling_xlm_roberta.py CHANGED
@@ -61,7 +61,7 @@ except ImportError:
61
  try:
62
  from flash_attn.losses.cross_entropy import CrossEntropyLoss
63
  except ImportError:
64
- CrossEntropyLoss = None
65
 
66
  try:
67
  from tqdm.autonotebook import trange
@@ -1168,14 +1168,15 @@ class XLMRobertaClassificationHead(nn.Module):
1168
 
1169
  def __init__(self, config):
1170
  super().__init__()
1171
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
 
1172
  classifier_dropout = (
1173
  config.classifier_dropout
1174
  if config.classifier_dropout is not None
1175
  else config.hidden_dropout_prob
1176
  )
1177
  self.dropout = nn.Dropout(classifier_dropout)
1178
- self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1179
 
1180
  def forward(self, features, **kwargs):
1181
  x = features[:, 0, :] # take <s> token (equiv. to [CLS])
 
61
  try:
62
  from flash_attn.losses.cross_entropy import CrossEntropyLoss
63
  except ImportError:
64
+ CrossEntropyLoss = torch.nn.CrossEntropyLoss
65
 
66
  try:
67
  from tqdm.autonotebook import trange
 
1168
 
1169
  def __init__(self, config):
1170
  super().__init__()
1171
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
1172
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
1173
  classifier_dropout = (
1174
  config.classifier_dropout
1175
  if config.classifier_dropout is not None
1176
  else config.hidden_dropout_prob
1177
  )
1178
  self.dropout = nn.Dropout(classifier_dropout)
1179
+ self.out_proj = linear_cls(config.hidden_size, config.num_labels)
1180
 
1181
  def forward(self, features, **kwargs):
1182
  x = features[:, 0, :] # take <s> token (equiv. to [CLS])