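"""BERT reader with an answerability-verification head for no-answer (SQuAD 2.0-style) QA.

`BertForQuestionAnsweringAVPool` predicts per-token answer start/end logits and,
from the [CLS] representation, a binary has-answer score. Training combines the
span loss with this verification loss; the weighting is labeled "Internal Front
Verification (I-FV)" below, which appears to follow the Retro-Reader recipe.
"""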
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (
    BertForSequenceClassification as SeqClassification,
    BertPreTrainedModel,
    BertModel,
    BertConfig,
)

from .modeling_outputs import (
    QuestionAnsweringModelOutput,
    QuestionAnsweringNaModelOutput,
)


# Thin subclass that pins `model_type` and re-exports the stock
# sequence-classification head from this module's namespace.
class BertForSequenceClassification(SeqClassification):
    model_type = "bert"


class BertForQuestionAnsweringAVPool(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    model_type = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        # Span head: projects every token to (start, end) logits.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        # Answerability head: binary has-answer / no-answer classifier over [CLS].
        self.has_ans = nn.Sequential(
            nn.Dropout(p=config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, 2),
        )

        # Initialize weights and apply final processing
        self.post_init()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        is_impossibles=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,  # honor the caller's return-format choice
        )
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # The [CLS] token representation feeds the answerability classifier.
        first_word = sequence_output[:, 0, :]
        has_logits = self.has_ans(first_word)
        total_loss = None
        if (
            start_positions is not None
            and end_positions is not None
            and is_impossibles is not None
        ):
            # On multi-GPU, gathering adds a trailing dimension; squeeze it away.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            if len(is_impossibles.size()) > 1:
                is_impossibles = is_impossibles.squeeze(-1)
            # Start/end positions outside the model inputs are clamped to
            # `ignored_index` and excluded from the loss.
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)
            is_impossibles.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            span_loss = start_loss + end_loss

            # Internal Front Verification (I-FV): weight the span loss and the
            # binary answerability loss with alpha1 = 1.0 and alpha2 = 0.5.
            choice_loss = loss_fct(has_logits, is_impossibles.long())
            total_loss = 1.0 * span_loss + 0.5 * choice_loss
        if not return_dict:
            output = (
                start_logits,
                end_logits,
                has_logits,
            ) + outputs[2:]  # hidden_states, attentions
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringNaModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            has_logits=has_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
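

# Minimal smoke test: a sketch, not part of the original module. It assumes the
# local `modeling_outputs.QuestionAnsweringNaModelOutput` exposes the
# `has_logits` field used above; the tiny config values are arbitrary so a
# randomly initialized forward pass stays cheap. Because of the relative import,
# run it as a module (`python -m <package>.<this_module>`), not as a script.
if __name__ == "__main__":
    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=2,  # start and end logits
    )
    model = BertForQuestionAnsweringAVPool(config)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        start_positions=torch.tensor([1, 3]),
        end_positions=torch.tensor([2, 5]),
        is_impossibles=torch.tensor([0, 1]),  # 0 = answerable, 1 = no answer
        return_dict=True,
    )
    print(out.loss, out.start_logits.shape, out.end_logits.shape, out.has_logits.shape)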