|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch ALBERT model.""" |
|
|
|
import math |
|
import os |
|
from dataclasses import dataclass |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
from ...activations import ACT2FN |
|
from ...modeling_outputs import ( |
|
BaseModelOutput, |
|
BaseModelOutputWithPooling, |
|
MaskedLMOutput, |
|
MultipleChoiceModelOutput, |
|
QuestionAnsweringModelOutput, |
|
SequenceClassifierOutput, |
|
TokenClassifierOutput, |
|
) |
|
from ...modeling_utils import PreTrainedModel |
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer |
|
from ...utils import ( |
|
ModelOutput, |
|
add_code_sample_docstrings, |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
from .configuration_albert import AlbertConfig |
|
|
|
|
|
# Module-level logger shared by the TensorFlow loader and the model classes in this file.
logger = logging.get_logger(__name__)

# Checkpoint name and config class name handed to the docstring decorators
# (`add_code_sample_docstrings`, `replace_return_docstrings`) used on the model `forward` methods.
_CHECKPOINT_FOR_DOC = "albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"


# ALBERT checkpoint identifiers: v1 and v2 variants in base / large / xlarge / xxlarge sizes.
# NOTE(review): not referenced in the visible part of this file; presumably consumed elsewhere
# (e.g. a model registry) — confirm before removing or renaming.
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "albert-base-v1",
    "albert-large-v1",
    "albert-xlarge-v1",
    "albert-xxlarge-v1",
    "albert-base-v2",
    "albert-large-v2",
    "albert-xlarge-v2",
    "albert-xxlarge-v2",
]
|
|
|
|
|
def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Reads every variable of the TensorFlow checkpoint at ``tf_checkpoint_path``, rewrites its TF
    scope name into the attribute path used by the PyTorch modules of this file, walks that path on
    ``model`` and overwrites the ``.data`` of the parameter it reaches. Optimizer state
    (``adam_m``/``adam_v``/``AdamWeightDecayOptimizer*``) and ``global_step`` are skipped.

    Args:
        model: PyTorch model whose parameters are overwritten in place.
        config: Not used by this function.
        tf_checkpoint_path: Path of the TensorFlow checkpoint (made absolute before use).

    Returns:
        The same ``model`` object, after loading.

    Raises:
        ImportError: If NumPy or TensorFlow cannot be imported.
        ValueError: If a checkpoint array's shape differs from the shape of the target parameter.
            The exception ``args`` are ``(message, target_shape, array_shape)``.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Read every (name, array) pair from the checkpoint up front. Each name is already logged here,
    # so no separate debug print of the names is needed.
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # Path components that belong to optimizer state rather than model weights.
    skipped_scopes = ("adam_m", "adam_v", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step")

    for name, array in zip(names, arrays):
        original_name = name

        # --- Rewrite the TF scope name into a PyTorch attribute path. ---
        # The replacements are order-dependent (e.g. "LayerNorm_1" must be rewritten before the
        # generic "LayerNorm" rule), so keep them in this exact sequence.
        name = name.replace("module/", "")

        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # Feed-forward scopes collapse onto `ffn` / `ffn_output`.
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # Attention "self"/"output" sub-scopes are flattened into the attention module.
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a single linear layer.
        name = name.replace("pooler/dense", "pooler")

        # "cls/predictions" maps onto the `predictions` head.
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Undo the "attention/LayerNorm" rewrite for embeddings; map group scopes to module lists.
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Top-level output_bias/output_weights belong to a classifier head.
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # The seq_relationship head maps onto the sentence-order-prediction classifier.
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        if any(scope in name for scope in skipped_scopes):
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        # --- Walk the attribute path on the model. ---
        pointer = model
        for m_name in name:
            # Components such as "layer_3" split into ["layer", "3", ""] so the index can be applied.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] in ("kernel", "gamma", "output_weights"):
                pointer = getattr(pointer, "weight")
            elif scope_names[0] in ("output_bias", "beta"):
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # NOTE: this `continue` skips only the current path component; `pointer` stays
                    # on the parent and the remaining components are still resolved. Presumably
                    # this is how a leading "albert" scope is ignored when `model` has no `albert`
                    # attribute — do not turn it into a skip of the whole variable without checking.
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # Transposed to match the PyTorch weight layout (presumably TF stores kernels as (in, out)).
            array = np.transpose(array)

        if pointer.shape != array.shape:
            raise ValueError(
                f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched",
                pointer.shape,
                array.shape,
            )
        logger.info(f"Initialize PyTorch weight {name} from {original_name}")
        pointer.data = torch.from_numpy(array)

    return model
|
|
|
|
|
class AlbertEmbeddings(nn.Module):
    """
    Embedding front-end of ALBERT.

    Adds token-type embeddings (and, when ``position_embedding_type == "absolute"``, position embeddings)
    to the word embeddings, then applies LayerNorm and dropout. Every table has width
    ``config.embedding_size``.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()
        # Submodules are created in a fixed order; keep it, since it determines parameter order.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Non-persistent helper buffers: default position ids 0..max_position_embeddings-1 and
        # all-zero token type ids, both of shape (1, max_position_embeddings).
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed ``input_ids`` (or take ``inputs_embeds`` as given) and return the normalised sum.

        Missing ``position_ids`` default to the range starting at ``past_key_values_length``; missing
        ``token_type_ids`` default to zeros.
        """
        # Leading (batch, seq) dims: drop the embedding axis when only embeddings were supplied.
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            first_position = past_key_values_length
            position_ids = self.position_ids[:, first_position : first_position + seq_length]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Reuse the zero buffer, broadcast over the batch, instead of allocating.
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)

        return self.dropout(self.LayerNorm(embeddings))
|
|
|
|
|
class AlbertAttention(nn.Module):
    """Self-attention sub-block of an ALBERT layer.

    Projects ``hidden_states`` to queries/keys/values, computes scaled dot-product attention over
    ``num_attention_heads`` heads (optionally adding relative-position scores), projects the context
    back with ``dense`` and applies dropout, a residual connection and ``LayerNorm``.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()
        # NOTE(review): the `hasattr(config, "embedding_size")` clause disables this check for any config
        # that defines `embedding_size`; confirm that is intended for ALBERT configs.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max-1), max-1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dim into (num_attention_heads, attention_head_size) and move heads to dim 1.

        A 3-D input ``(d0, d1, all_head_size)`` becomes ``(d0, num_attention_heads, d1, attention_head_size)``.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def prune_heads(self, heads: List[int]) -> None:
        """Remove the given attention heads from the q/k/v and output projections (no-op for an empty list)."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune the output rows of q/k/v and the matching input columns of the output projection.
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Keep the head bookkeeping consistent with the pruned layers.
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """Run self-attention on ``hidden_states``.

        Args:
            hidden_states: Input activations; also used as the residual branch.
            attention_mask: Additive mask added to the raw attention scores, if given.
            head_mask: Multiplier applied to the attention probabilities, if given.
            output_attentions: Also return the attention probabilities.

        Returns:
            ``(layernormed_context_layer,)`` or ``(layernormed_context_layer, attention_probs)``.
        """
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Raw scores: query . key^T, scaled by 1/sqrt(head size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # The mask is additive (large negative values where attention is disallowed).
            attention_scores = attention_scores + attention_mask

        # NOTE(review): relative-position scores are added after scaling and after the mask, so they are
        # not divided by sqrt(head size); confirm this ordering is intended before changing it.
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift signed distances into the non-negative index range of `distance_embedding`.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout on the probabilities drops whole tokens to attend to.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        # Back to (batch, seq, heads * head_size).
        context_layer = context_layer.transpose(2, 1).flatten(2)

        projected_context_layer = self.dense(context_layer)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
|
|
|
|
|
class AlbertLayer(nn.Module):
    """One ALBERT transformer layer.

    ``AlbertAttention`` followed by a feed-forward network (``ffn`` -> activation -> ``ffn_output``),
    a residual connection around the feed-forward network and ``full_layer_layer_norm``.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension along which `apply_chunking_to_forward` chunks the feed-forward input.
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        # NOTE(review): this dropout module is never applied in `forward` or `ff_chunk`.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention then the (optionally chunked) feed-forward network.

        Returns ``(layer_output,)`` plus the attention probabilities when ``output_attentions`` is set.
        ``output_hidden_states`` is accepted but not used here.
        """
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
        attended_states = attention_outputs[0]

        feed_forward_states = apply_chunking_to_forward(
            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attended_states
        )
        layer_states = self.full_layer_layer_norm(feed_forward_states + attended_states)

        return (layer_states, *attention_outputs[1:])

    def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
        """Feed-forward network for one chunk: ``ffn_output(activation(ffn(attention_output)))``."""
        return self.ffn_output(self.activation(self.ffn(attention_output)))
|
|
|
|
|
class AlbertLayerGroup(nn.Module):
    """A stack of ``config.inner_group_num`` ``AlbertLayer`` modules applied one after another.

    ``forward`` returns ``(hidden_states,)``, followed by the tuple of per-layer hidden states when
    ``output_hidden_states`` is set, followed by the tuple of per-layer attention probabilities when
    ``output_attentions`` is set.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run every inner layer in order.

        Args:
            hidden_states: Input activations of the first inner layer.
            attention_mask: Additive attention mask forwarded to every layer.
            head_mask: Indexable per inner layer (``head_mask[i]`` goes to layer ``i``). ``None`` means
                no head masking for any layer.
            output_attentions: Collect each layer's attention probabilities.
            output_hidden_states: Collect each layer's output hidden states.
        """
        # `head_mask` defaults to None but is indexed per layer below; expand it so the default works.
        if head_mask is None:
            head_mask = [None] * len(self.albert_layers)

        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs
|
|
|
|
|
class AlbertTransformer(nn.Module):
    """ALBERT encoder body.

    Projects embeddings from ``embedding_size`` to ``hidden_size`` and then performs
    ``config.num_hidden_layers`` passes. Only ``config.num_hidden_groups`` ``AlbertLayerGroup`` modules
    exist, and consecutive layer indices reuse the same group, so parameters are shared across layers.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[BaseModelOutput, Tuple]:
        """Encode ``hidden_states`` (embedding outputs).

        Returns a ``BaseModelOutput`` or, when ``return_dict`` is False, a tuple of the non-None values
        among (last hidden state, all hidden states, all attentions). Collected hidden states start with
        the projected embeddings.
        """
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        num_hidden_layers = self.config.num_hidden_layers
        num_hidden_groups = self.config.num_hidden_groups

        head_mask = [None] * num_hidden_layers if head_mask is None else head_mask

        # Number of consecutive layers served by one group; loop-invariant, so computed once.
        layers_per_group = num_hidden_layers // num_hidden_groups

        for i in range(num_hidden_layers):
            # Group for layer i is floor(i * groups / layers); integer arithmetic keeps it exact
            # instead of going through float division.
            group_idx = i * num_hidden_groups // num_hidden_layers

            # NOTE(review): every pass through a group receives the same slice of `head_mask`
            # (selected by group, not by `i`), and the group indexes it per inner layer; confirm
            # that head masks for individual layer passes are meant to be shared this way.
            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                # The group returns its per-inner-layer attentions as its last element.
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
|
|
|
|
|
class AlbertPreTrainedModel(PreTrainedModel):
    """
    Shared base of the ALBERT models: binds `AlbertConfig`, the TensorFlow checkpoint loader and the
    `"albert"` base-model prefix, and defines how submodule weights are initialised.
    """

    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"

    def _init_weights(self, module):
        """Initialise ``module`` in place according to its type.

        - ``nn.Linear``: weight ~ N(0, config.initializer_range); bias (if present) zeroed.
        - ``nn.Embedding``: weight ~ N(0, config.initializer_range); padding row (if any) zeroed.
        - ``nn.LayerNorm``: weight set to 1, bias set to 0.
        Any other module type is left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
|
@dataclass
class AlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`AlbertForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the sentence order prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the sentence order prediction (classification) head (scores for original vs.
            swapped sentence order, before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order defines the tuple order when this output is converted to a tuple.
    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    sop_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
|
ALBERT_START_DOCSTRING = r""" |
|
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
|
etc.) |
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
|
and behavior. |
|
|
|
Args: |
|
config ([`AlbertConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
ALBERT_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `({0})`): |
|
Indices of input sequence tokens in the vocabulary. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and |
|
[`PreTrainedTokenizer.encode`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): |
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, |
|
1]`: |
|
|
|
- 0 corresponds to a *sentence A* token, |
|
- 1 corresponds to a *sentence B* token. |
|
|
|
[What are token type IDs?](../glossary#token-type-ids) |
|
position_ids (`torch.LongTensor` of shape `({0})`, *optional*): |
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
|
config.max_position_embeddings - 1]`. |
|
|
|
[What are position IDs?](../glossary#position-ids) |
|
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): |
|
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
|
|
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): |
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the |
|
model's internal embedding lookup matrix. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
|
|
@add_start_docstrings(
    "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):
    # Embeddings -> AlbertTransformer encoder -> optional tanh pooler over the first token.
    config_class = AlbertConfig
    base_model_prefix = "albert"

    def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
        """Build embeddings, encoder and (when ``add_pooling_layer``) the first-token pooler."""
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the word-embedding table."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the word-embedding table."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.

        ALBERT shares layers: there are `num_hidden_groups` layer groups, each holding `inner_group_num` distinct
        layers. The keys of `heads_to_prune` index these distinct layers, flattened group by group. For example, a
        model with 2 hidden groups and 2 inner layers per group has 4 distinct layers (whatever its number of hidden
        layers): [0, 1] are the two inner layers of the first group and [2, 3] those of the second group.

        In that example, indices other than [0, 1, 2, 3] do not address a distinct layer (positive out-of-range
        indices raise an error). See base class PreTrainedModel for more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutputWithPooling, Tuple]:
        # Unset flags fall back to the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the embeddings' zero buffer, broadcast over the batch.
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Turn the 0/1 mask into an additive mask that broadcasts against the attention scores
        # (batch, num_heads, query_len, key_len): kept positions add 0, masked ones the dtype minimum.
        if attention_mask.dim() == 3:
            # (batch, query_len, key_len) -> (batch, 1, query_len, key_len)
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            # (batch, seq_len) -> (batch, 1, 1, seq_len)
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        # Pool by projecting the first token's final hidden state through dense + tanh.
        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
|
@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    # Parameter names that share storage with other weights. `predictions.decoder.bias` is the very same Parameter
    # as `predictions.bias` (assigned in AlbertMLMHead.__init__). The decoder weight is presumably tied to the input
    # word embeddings by the base class via get_output_embeddings/get_input_embeddings — confirm in PreTrainedModel.
    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config: AlbertConfig):
        super().__init__(config)

        # Backbone with pooling (the SOP head consumes the pooled output), plus the two pretraining heads.
        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        # Base-class post-construction hook (defined in PreTrainedModel, outside this chunk).
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        """Return the MLM output projection (vocabulary decoder)."""
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        """Replace the MLM output projection."""
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the backbone's word-embedding table."""
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        sentence_order_label: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[AlbertForPreTrainingOutput, Tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
        sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence
            pair (see `input_ids` docstring). Indices should be in `[0, 1]`. `0` indicates original order (sequence
            A, then sequence B), `1` indicates switched order (sequence B, then sequence A).

            The combined loss is only computed when both `labels` and `sentence_order_label` are provided.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, AlbertForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        >>> model = AlbertForPreTraining.from_pretrained("albert-base-v2")

        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
        >>> # Batch size 1
        >>> outputs = model(input_ids)

        >>> prediction_logits = outputs.prediction_logits
        >>> sop_logits = outputs.sop_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            # The SOP head emits config.num_labels columns; reshape with the actual width rather than a hard-coded 2,
            # which raised a shape mismatch whenever num_labels != 2 (identical result for the default of 2).
            sentence_order_loss = loss_fct(sop_scores.view(-1, sop_scores.size(-1)), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
class AlbertMLMHead(nn.Module):
    """Masked-language-modeling head.

    Projects ``hidden_size`` features down to ``embedding_size`` (dense -> activation -> LayerNorm) and then scores
    every vocabulary entry with ``decoder``. The decoder's bias is the same ``nn.Parameter`` object as ``self.bias``.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()

        # NOTE: creation order below is intentionally unchanged; it fixes the order in which parameters/submodules
        # are registered on this module.
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]
        # Share a single bias Parameter between `self.bias` and the output projection.
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return unnormalized vocabulary scores (last dimension ``config.vocab_size``)."""
        normalized = self.LayerNorm(self.activation(self.dense(hidden_states)))
        return self.decoder(normalized)

    def _tie_weights(self) -> None:
        # Re-point `self.bias` at the decoder's current bias, presumably so the two stay shared after the decoder
        # has been swapped out (e.g. via set_output_embeddings) — confirm against the base-class tying logic.
        self.bias = self.decoder.bias
|
|
|
|
|
class AlbertSOPHead(nn.Module):
    """Sentence-order-prediction head: dropout followed by a linear classifier over the pooled output."""

    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        """Map pooled features (last dim ``hidden_size``) to ``config.num_labels`` logits."""
        return self.classifier(self.dropout(pooled_output))
|
|
|
|
|
@add_start_docstrings(
    "Albert Model with a `language modeling` head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
    # Decoder parameters that share storage with other weights: the bias is the same Parameter as
    # `predictions.bias`; the weight is presumably tied to the input embeddings by the base class.
    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        # No pooler: the MLM head only consumes per-token hidden states.
        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        # Base-class post-construction hook (defined in PreTrainedModel, outside this chunk).
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        """Return the MLM output projection (vocabulary decoder)."""
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        """Replace the MLM output projection."""
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the backbone's word-embedding table."""
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MaskedLMOutput, Tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, AlbertForMaskedLM

        >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        >>> model = AlbertForMaskedLM.from_pretrained("albert-base-v2")

        >>> # add mask_token
        >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> # retrieve index of [MASK]
        >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
        >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
        >>> tokenizer.decode(predicted_token_id)
        'france'
        ```

        ```python
        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
        >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
        >>> outputs = model(**inputs, labels=labels)
        >>> round(outputs.loss.item(), 2)
        0.81
        ```

        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        prediction_scores = self.predictions(outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # Flatten tokens so every position is one classification over the vocabulary; targets equal to
            # CrossEntropyLoss's default ignore_index (-100) are skipped.
            flat_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = CrossEntropyLoss()(flat_scores, labels.view(-1))

        if return_dict:
            return MaskedLMOutput(
                loss=masked_lm_loss,
                logits=prediction_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple mode: index 1 holds the pooled-output slot (presumably None, as there is no pooler here), so any
        # hidden states / attentions start at index 2.
        tail = (prediction_scores, *outputs[2:])
        return tail if masked_lm_loss is None else (masked_lm_loss, *tail)
|
|
|
|
|
@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Backbone with pooling; the classifier reads the pooled vector.
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Base-class post-construction hook (defined in PreTrainedModel, outside this chunk).
        self.post_init()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="textattack/albert-base-v2-imdb",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'LABEL_1'",
        expected_loss=0.12,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # outputs[1] is the pooled representation produced by the backbone's pooler.
        logits = self.classifier(self.dropout(outputs[1]))

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # First labelled call: infer the task from num_labels and the label dtype, then cache it on the config
                # so later calls skip the inference.
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem = self.config.problem_type
            if problem == "regression":
                mse = MSELoss()
                loss = mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else mse(logits, labels)
            elif problem == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)
            # Any other problem_type string leaves `loss` as None.

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        extras = (logits,) + outputs[2:]
        return extras if loss is None else (loss,) + extras
|
|
|
|
|
@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        # No pooler: classification is per token.
        self.albert = AlbertModel(config, add_pooling_layer=False)

        # Fall back to the general hidden dropout when no classifier-specific rate is configured.
        dropout_prob = config.classifier_dropout_prob
        if dropout_prob is None:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Base-class post-construction hook (defined in PreTrainedModel, outside this chunk).
        self.post_init()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[TokenClassifierOutput, Tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token logits from the last hidden states (outputs[0]).
        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            # Flatten every token into one classification row.
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        extras = (logits,) + outputs[2:]
        return extras if loss is None else (loss,) + extras
|
|
|
|
|
@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        # No pooler: span logits are computed per token.
        self.albert = AlbertModel(config, add_pooling_layer=False)
        # One column per span boundary; forward() unpacks exactly two columns, so this assumes num_labels == 2.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Base-class post-construction hook (defined in PreTrainedModel, outside this chunk).
        self.post_init()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="twmkn9/albert-base-v2-squad2",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=12,
        qa_target_end_index=13,
        expected_output="'a nice puppet'",
        expected_loss=7.36,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[QuestionAnsweringModelOutput, Tuple]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Split the two output columns into start/end scores, dropping the trailing size-1 dimension.
        logits: torch.Tensor = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Accept positions given with an extra trailing dimension, e.g. `(batch_size, 1)`.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions past the end of the sequence are clamped to `sequence_length`, which is also the loss's
            # ignore_index, so they contribute nothing. Note: negative positions are clamped to 0 and DO count.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)

        # Backbone with pooling; every (example, choice) pair is scored from its pooled vector.
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        # A single score per choice; scores are regrouped per example in forward().
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Base-class post-construction hook (defined in PreTrainedModel, outside this chunk).
        self.post_init()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MultipleChoiceModelOutput, Tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
            *input_ids* above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Fail with a clear message instead of an AttributeError on `None.shape` below.
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice dimension into the batch so the backbone sees ordinary 2D (or 3D for embeddings) inputs:
        # (batch_size, num_choices, ...) -> (batch_size * num_choices, ...).
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits: torch.Tensor = self.classifier(pooled_output)
        # Un-fold: one row per example, one column per choice.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|