|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" PyTorch Data2VecAudio model.""" |
|
|
|
import math |
|
import warnings |
|
from typing import Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
|
|
from ...activations import ACT2FN |
|
from ...integrations.deepspeed import is_deepspeed_zero3_enabled |
|
from ...modeling_outputs import ( |
|
BaseModelOutput, |
|
CausalLMOutput, |
|
SequenceClassifierOutput, |
|
TokenClassifierOutput, |
|
Wav2Vec2BaseModelOutput, |
|
XVectorOutput, |
|
) |
|
from ...modeling_utils import PreTrainedModel |
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging |
|
from .configuration_data2vec_audio import Data2VecAudioConfig |
|
|
|
|
|
# Module-level logger for this file.
logger = logging.get_logger(__name__)


# Index in the tuple returned by `Data2VecAudioModel.forward(..., return_dict=False)` at which the
# encoder outputs start (index 0 is `last_hidden_state`, index 1 is `extract_features`).
# NOTE(review): presumably consumed by task heads defined later in the file — confirm.
_HIDDEN_STATES_START_POSITION = 2


# Config class name injected into generated code-sample docstrings.
_CONFIG_FOR_DOC = "Data2VecAudioConfig"


# Checkpoint and expected output shape used by `add_code_sample_docstrings` on
# `Data2VecAudioModel.forward`.
_CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]


# Expected transcription / loss for a CTC doc example.
# NOTE(review): presumably used by the CTC head's doc decorator outside this chunk — confirm.
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 66.95


# Hugging Face Hub identifiers of pretrained Data2VecAudio checkpoints.
DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/data2vec-audio-base",
    "facebook/data2vec-audio-base-10m",
    "facebook/data2vec-audio-base-100h",
    "facebook/data2vec-audio-base-960h",
]
|
|
|
|
|
|
|
def _compute_mask_indices( |
|
shape: Tuple[int, int], |
|
mask_prob: float, |
|
mask_length: int, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
min_masks: int = 0, |
|
) -> np.ndarray: |
|
""" |
|
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for |
|
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on |
|
CPU as part of the preprocessing during training. |
|
|
|
Args: |
|
shape: The shape for which to compute masks. This should be of a tuple of size 2 where |
|
the first element is the batch size and the second element is the length of the axis to span. |
|
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of |
|
independently generated mask spans of length `mask_length` is computed by |
|
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the |
|
actual percentage will be smaller. |
|
mask_length: size of the mask |
|
min_masks: minimum number of masked spans |
|
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of |
|
each batch dimension. |
|
""" |
|
batch_size, sequence_length = shape |
|
|
|
if mask_length < 1: |
|
raise ValueError("`mask_length` has to be bigger than 0.") |
|
|
|
if mask_length > sequence_length: |
|
raise ValueError( |
|
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" |
|
f" and `sequence_length`: {sequence_length}`" |
|
) |
|
|
|
|
|
epsilon = np.random.rand(1).item() |
|
|
|
def compute_num_masked_span(input_length): |
|
"""Given input length, compute how many spans should be masked""" |
|
num_masked_span = int(mask_prob * input_length / mask_length + epsilon) |
|
num_masked_span = max(num_masked_span, min_masks) |
|
|
|
|
|
if num_masked_span * mask_length > sequence_length: |
|
num_masked_span = sequence_length // mask_length |
|
|
|
|
|
if input_length - (mask_length - 1) < num_masked_span: |
|
num_masked_span = max(input_length - (mask_length - 1), 0) |
|
|
|
return num_masked_span |
|
|
|
|
|
input_lengths = ( |
|
attention_mask.sum(-1).detach().tolist() |
|
if attention_mask is not None |
|
else [sequence_length for _ in range(batch_size)] |
|
) |
|
|
|
|
|
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) |
|
spec_aug_mask_idxs = [] |
|
|
|
max_num_masked_span = compute_num_masked_span(sequence_length) |
|
|
|
if max_num_masked_span == 0: |
|
return spec_aug_mask |
|
|
|
for input_length in input_lengths: |
|
|
|
num_masked_span = compute_num_masked_span(input_length) |
|
|
|
|
|
spec_aug_mask_idx = np.random.choice( |
|
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False |
|
) |
|
|
|
|
|
|
|
|
|
if len(spec_aug_mask_idx) == 0: |
|
|
|
|
|
|
|
dummy_mask_idx = sequence_length - 1 |
|
else: |
|
dummy_mask_idx = spec_aug_mask_idx[0] |
|
|
|
spec_aug_mask_idx = np.concatenate( |
|
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] |
|
) |
|
spec_aug_mask_idxs.append(spec_aug_mask_idx) |
|
|
|
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) |
|
|
|
|
|
spec_aug_mask_idxs = np.broadcast_to( |
|
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) |
|
) |
|
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) |
|
|
|
|
|
offsets = np.arange(mask_length)[None, None, :] |
|
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( |
|
batch_size, max_num_masked_span * mask_length |
|
) |
|
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets |
|
|
|
|
|
if spec_aug_mask_idxs.max() > sequence_length - 1: |
|
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 |
|
|
|
|
|
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) |
|
|
|
return spec_aug_mask |
|
|
|
|
|
class Data2VecAudioConvLayer(nn.Module):
    """One feature-encoder block: Conv1d, then LayerNorm over the channel axis, then the activation.

    Layer ``layer_id`` maps ``config.conv_dim[layer_id - 1]`` input channels (a single channel for the
    first layer) to ``config.conv_dim[layer_id]`` output channels.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        if layer_id > 0:
            self.in_conv_dim = config.conv_dim[layer_id - 1]
        else:
            self.in_conv_dim = 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        # Conv1d works on (batch, channels, length).
        convolved = self.conv(hidden_states)
        # LayerNorm normalizes the last axis, so move channels last and back again.
        normalized = self.layer_norm(convolved.transpose(-2, -1)).transpose(-2, -1)
        return self.activation(normalized)
|
|
|
|
|
|
|
class Data2VecAudioPadLayer(nn.Module):
    """Drops the extra trailing time step produced by `kernel_size // 2` padding with an even kernel.

    With padding `k // 2`, a Conv1d of even kernel size `k` returns one more step than its input, so that
    step is trimmed; odd kernels need no trimming.
    """

    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        self.num_pad_remove = int(num_conv_pos_embeddings % 2 == 0)

    def forward(self, hidden_states):
        trim = self.num_pad_remove
        if trim <= 0:
            return hidden_states
        return hidden_states[:, :, :-trim]
|
|
|
|
|
class Data2VecAudioPositionalConvLayer(nn.Module):
    """Grouped Conv1d over `hidden_size` channels, trim, non-affine LayerNorm, activation.

    Input and output are channels-first: (batch, hidden_size, time).
    """

    def __init__(self, config):
        super().__init__()
        kernel_size = config.conv_pos_kernel_size
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=config.num_conv_pos_embedding_groups,
        )
        # even kernels produce one extra time step with this padding; the pad layer removes it
        self.padding = Data2VecAudioPadLayer(kernel_size)
        self.activation = ACT2FN[config.feat_extract_activation]
        self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)

    def forward(self, hidden_states):
        trimmed = self.padding(self.conv(hidden_states))
        # normalize over channels: move them last for LayerNorm, then restore channels-first
        normalized = self.layer_norm(trimmed.transpose(1, 2)).transpose(1, 2)
        return self.activation(normalized)
|
|
|
|
|
class Data2VecAudioPositionalConvEmbedding(nn.Module):
    """Stack of `config.num_conv_pos_embeddings` positional conv layers.

    Takes and returns tensors laid out as (batch, time, hidden_size); the layers run channels-first.
    """

    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList(
            Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)
        )

    def forward(self, hidden_states):
        channels_first = hidden_states.transpose(1, 2)
        for conv_layer in self.layers:
            channels_first = conv_layer(channels_first)
        return channels_first.transpose(1, 2)
|
|
|
|
|
class Data2VecAudioFeatureEncoder(nn.Module):
    """Convolutional front-end that turns a raw waveform into a sequence of feature frames.

    Input is `(batch, samples)`; output is channels-first `(batch, conv_dim[-1], frames)`.
    """

    def __init__(self, config):
        super().__init__()
        self.conv_layers = nn.ModuleList(
            Data2VecAudioConvLayer(config, layer_id=idx) for idx in range(config.num_feat_extract_layers)
        )
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        """Turn off gradients for every parameter and remember that the encoder is frozen."""
        for parameter in self.parameters():
            parameter.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        # add a single channel axis for the first Conv1d
        hidden_states = input_values[:, None]

        # the input must require grad so that gradient checkpointing has something to backprop through
        if self._requires_grad and self.training:
            hidden_states.requires_grad = True

        checkpoint_layers = self._requires_grad and self.gradient_checkpointing and self.training
        for conv_layer in self.conv_layers:
            if checkpoint_layers:
                hidden_states = torch.utils.checkpoint.checkpoint(conv_layer, hidden_states)
            else:
                hidden_states = conv_layer(hidden_states)

        return hidden_states
|
|
|
|
|
|
|
class Data2VecAudioFeatureProjection(nn.Module):
    """Normalizes conv features and projects them from `conv_dim[-1]` to `hidden_size`, with dropout."""

    def __init__(self, config):
        super().__init__()
        conv_out_dim = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(conv_out_dim, eps=config.layer_norm_eps)
        self.projection = nn.Linear(conv_out_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """Return `(projected, normalized)`, where `normalized` is the LayerNorm output fed to the projection."""
        normalized = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(normalized))
        return projected, normalized
|
|
|
|
|
|
|
class Data2VecAudioAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Supports self-attention, cross-attention (when `key_value_states` is given) and a key/value cache
    (`past_key_value`). The encoder layers in this file use it as plain self-attention (`is_decoder=False`).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # queries are pre-multiplied by 1/sqrt(head_dim) (scaled dot-product attention)
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the last dim into heads: (bsz, seq, embed) -> (bsz, num_heads, seq, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query source of shape `(bsz, tgt_len, embed_dim)`.
            key_value_states: If given, keys/values are projected from it (cross-attention).
            past_key_value: Cached `(key_states, value_states)`, each `(bsz, num_heads, len, head_dim)`.
            attention_mask: Additive mask of shape `(bsz, 1, tgt_len, src_len)`, added to the raw scores.
            layer_head_mask: Per-head multiplier of shape `(num_heads,)`, applied after the softmax.
            output_attentions: Whether to also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights_reshaped, past_key_value)`: `attn_output` is `(bsz, tgt_len, embed_dim)`;
            `attn_weights_reshaped` is `(bsz, num_heads, tgt_len, src_len)` or `None`; `past_key_value` is the new
            `(key_states, value_states)` cache when `is_decoder`, otherwise the value that was passed in.

        Raises:
            ValueError: if the attention weights, mask, head mask or output have an unexpected shape.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # query projection, scaled up front
        query_states = self.q_proj(hidden_states) * self.scaling

        # Key/value sources:
        # 1. cross-attention with a cache whose length matches key_value_states -> reuse the cache
        # 2. cross-attention without a usable cache -> project key_value_states
        # 3. self-attention with a cache -> project the new steps and append them to the cache
        # 4. plain self-attention -> project hidden_states
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            # concatenate along the sequence axis (dim 2 of (bsz, heads, seq, head_dim))
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Return the projected keys/values so the caller can pass them back as the cache on the next
            # call. For non-decoders, past_key_value is returned unchanged.
            past_key_value = (key_states, value_states)

        # fold heads into the batch dim so a single bmm computes all heads
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # additive mask, broadcast over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            # scale each head's probabilities by its mask value
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # Keep a per-head view for the caller and re-derive the flattened tensor from it, so the returned
            # weights stay connected to the graph used below.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        # dropout on the probabilities only; the returned weights are the pre-dropout ones
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # unfold heads and merge them back into embed_dim
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # reshape (not view) because the transpose above makes the tensor non-contiguous
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
|
|
|
|
|
|
class Data2VecAudioFeedForward(nn.Module):
    """Position-wise MLP: dense -> activation -> dropout -> dense -> dropout.

    `config.hidden_act` may be a key into `ACT2FN` or a callable used as-is.
    """

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded)
        return self.output_dropout(self.output_dense(expanded))
|
|
|
|
|
|
|
class Data2VecAudioEncoderLayer(nn.Module):
    """Post-norm transformer layer.

    Self-attention with dropout is followed by residual + LayerNorm; the feed-forward block is followed by
    residual + final LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = Data2VecAudioAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = Data2VecAudioFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        attended, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        # residual around attention, then norm
        hidden_states = self.layer_norm(hidden_states + self.dropout(attended))
        # residual around the feed-forward block, then the final norm
        hidden_states = self.final_layer_norm(hidden_states + self.feed_forward(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
|
class Data2VecAudioEncoder(nn.Module):
    """Transformer encoder stack.

    Applies a convolutional positional embedding, LayerNorm and dropout, then `config.num_hidden_layers`
    encoder layers with LayerDrop (`config.layerdrop`) during training.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = Data2VecAudioPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states: Projected features of shape `(batch, frames, hidden_size)`. NOTE: when
                `attention_mask` is given, padded positions of this tensor are zeroed in place.
            attention_mask: Frame-level mask of shape `(batch, frames)`, True for valid frames. It is used with
                `~` below, so it is assumed to be boolean — TODO confirm for callers other than
                `Data2VecAudioModel`.
            output_attentions: Collect each layer's attention weights (None for dropped layers).
            output_hidden_states: Collect the input to every layer plus the final output.
            return_dict: Return a `BaseModelOutput` instead of a tuple.

        Returns:
            `BaseModelOutput`, or a tuple of the non-None values among
            `(last_hidden_state, all_hidden_states, all_self_attentions)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded frames are 0 before the positional conv mixes neighbouring frames
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

            # Convert to an additive mask: 0 for valid key positions, dtype-min for padded ones,
            # broadcast to (batch, 1, frames, frames) as the attention module expects.
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
            )

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # LayerDrop (https://arxiv.org/abs/1909.11556): one draw per layer, consumed even in eval mode
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # Under DeepSpeed ZeRO-3 every rank must run every layer to stay in sync.
                # NOTE(review): in that case a "skipped" layer still updates hidden_states below; only its
                # attention output is reported as None.
                if self.gradient_checkpointing and self.training:

                    # wrap the layer so checkpoint() can pass output_attentions positionally
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(layer),
                        hidden_states,
                        attention_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # the final output is appended as the last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
|
|
|
|
|
|
class Data2VecAudioAdapter(nn.Module):
    """Optional post-encoder adapter.

    Consists of an optional projection to `output_hidden_size` followed by a stack of strided conv layers
    with LayerDrop. It takes and returns `(batch, time, features)`.
    """

    def __init__(self, config):
        super().__init__()

        # project only when the encoder width differs from the requested output width
        if config.output_hidden_size == config.hidden_size:
            self.proj = None
            self.proj_layer_norm = None
        else:
            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
            self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)

        self.layers = nn.ModuleList([Data2VecAudioAdapterLayer(config) for _ in range(config.num_adapter_layers)])
        self.layerdrop = config.layerdrop

    def forward(self, hidden_states):
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))

        channels_first = hidden_states.transpose(1, 2)

        for adapter_layer in self.layers:
            # one draw per layer regardless of mode, so the numpy RNG stream advances identically
            draw = np.random.random()
            if self.training and draw <= self.layerdrop:
                continue
            channels_first = adapter_layer(channels_first)

        return channels_first.transpose(1, 2)
|
|
|
|
|
|
|
class Data2VecAudioAdapterLayer(nn.Module):
    """Strided Conv1d that doubles the channel count, followed by a GLU that halves it back.

    Expects channels-first input `(batch, output_hidden_size, time)`.
    """

    def __init__(self, config):
        super().__init__()
        channels = config.output_hidden_size
        self.conv = nn.Conv1d(
            channels,
            2 * channels,
            config.adapter_kernel_size,
            stride=config.adapter_stride,
            padding=1,
        )

    def forward(self, hidden_states):
        # GLU splits the doubled channel axis in two and gates the first half with sigmoid(second half)
        return nn.functional.glu(self.conv(hidden_states), dim=1)
|
|
|
|
|
class Data2VecAudioPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Data2VecAudioConfig
    base_model_prefix = "data2vec_audio"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, Data2VecAudioFeatureProjection):
            # uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)] for the feature projection
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, Data2VecAudioPositionalConvLayer):
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            # non-affine norms have weight/bias set to None
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                # uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)], with fan_in accounting for groups
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers

        Args:
            input_lengths: Number of raw samples per example (tensor or int).
            add_adapter: Also account for the adapter layers; defaults to `config.add_adapter`.

        Returns:
            Frame counts after the feature encoder (and adapter, if requested), as a tensor.
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # Conv1d output length without padding/dilation: floor((L - kernel) / stride) + 1
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # NOTE(review): adapter layers are modeled as kernel-1 convs with `adapter_stride`; with the adapter's
            # padding=1 this matches the real layer only when `adapter_kernel_size == 3`.
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        """
        Downsample a sample-level, right-padded attention mask to the feature-frame resolution.

        Args:
            feature_vector_length: Number of frames produced by the feature encoder.
            attention_mask: `(batch, samples)` mask, 1/True for real samples.
            add_adapter: Forwarded to `_get_feat_extract_output_lengths`.

        Returns:
            Boolean `(batch, feature_vector_length)` mask, True for frames derived from real samples. Examples whose
            computed frame count is <= 0 get an all-False row.
        """
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # Mark the last valid frame of every example; the reversed cumsum below then fills all frames before it.
        # Rows with no valid frame are skipped: indexing them at `output_lengths - 1 < 0` would wrap around to the
        # end of the row and mark every frame as valid.
        batch_indices = torch.arange(batch_size, device=attention_mask.device)
        has_frames = output_lengths > 0
        attention_mask[(batch_indices[has_frames], output_lengths[has_frames] - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

    def _set_gradient_checkpointing(self, module, value=False):
        # only the encoder stack and the conv feature encoder implement gradient checkpointing
        if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)):
            module.gradient_checkpointing = value
|
|
|
|
|
DATA2VEC_AUDIO_START_DOCSTRING = r""" |
|
Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and |
|
Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and |
|
Michael Auli. |
|
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its models (such as downloading or saving, etc.).
|
|
|
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use |
|
it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
|
behavior. |
|
|
|
Parameters: |
|
config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
|
|
DATA2VEC_AUDIO_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): |
|
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file |
|
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install |
|
soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and |
|
conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details. |
|
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, |
|
1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
|
|
<Tip warning={true}> |
|
|
|
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == |
|
True`. For all models whose processor has `config.return_attention_mask == False`, such as |
|
[data2vec-audio-base-960h](https://huggingface.co/facebook/data2vec-audio-base-960h), `attention_mask` should
|
**not** be passed to avoid degraded performance when doing batched inference. For such models |
|
`input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these |
|
models also yield slightly different results depending on whether `input_values` is padded or not. |
|
|
|
</Tip> |
|
|
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
|
|
@add_start_docstrings(
    "The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
    # Pipeline: raw waveform -> conv feature encoder -> feature projection -> optional SpecAugment masking
    # -> transformer encoder -> optional adapter.

    def __init__(self, config: Data2VecAudioConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = Data2VecAudioFeatureEncoder(config)
        self.feature_projection = Data2VecAudioFeatureProjection(config)

        # the learned mask embedding only exists when SpecAugment masking is configured
        uses_spec_augment = config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0
        if uses_spec_augment:
            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

        self.encoder = Data2VecAudioEncoder(config)
        self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        """
        Stop gradient computation for the convolutional feature encoder so its weights stay fixed during
        training.
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Apply [SpecAugment](https://arxiv.org/abs/1904.08779)-style masking in place and return `hidden_states`.

        Time steps are replaced by `masked_spec_embed` (using `mask_time_indices` when given, otherwise random spans
        while training). Feature channels are zeroed with random spans while training.
        """
        # masking can be switched off entirely through `config.apply_spec_augment`
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        batch_size, sequence_length, hidden_size = hidden_states.size()

        # sample time masks only when none were supplied by the caller
        if mask_time_indices is None and self.config.mask_time_prob > 0 and self.training:
            sampled_time_mask = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(sampled_time_mask, device=hidden_states.device, dtype=torch.bool)

        if mask_time_indices is not None:
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # one feature mask per example, shared across all time steps
            sampled_feature_mask = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            feature_mask = torch.tensor(sampled_feature_mask, device=hidden_states.device, dtype=torch.bool)
            hidden_states[feature_mask[:, None].expand(-1, sequence_length, -1)] = 0

        return hidden_states

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Wav2Vec2BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        # fall back to config defaults for unspecified output flags
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # conv features are channels-first; switch to (batch, frames, channels)
        extract_features = self.feature_extractor(input_values).transpose(1, 2)

        if attention_mask is not None:
            # downsample the sample-level mask to one entry per feature frame
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]
        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if return_dict:
            return Wav2Vec2BaseModelOutput(
                last_hidden_state=hidden_states,
                extract_features=extract_features,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (hidden_states, extract_features) + encoder_outputs[1:]
|
|
|
|
|
@add_start_docstrings(
    """Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.data2vec_audio = Data2VecAudioModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # When `config.add_adapter` is set the head reads `config.output_hidden_size` features instead of
        # `config.hidden_size`.
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.

        Deprecated: emits a `FutureWarning` and delegates to `freeze_feature_encoder`.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.data2vec_audio.feature_extractor._freeze_parameters()

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.data2vec_audio(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Valid label ids are 0 .. vocab_size - 1, so vocab_size itself is already out of range.
            if labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

            # retrieve loss input_lengths from attention_mask; without a mask every input sample is counted
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # Negative label ids (e.g. -100) mark padding and are dropped from the flattened targets.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # Log-probs are computed in float32 and time is moved to axis 0, the (T, N, C) layout `ctc_loss` expects.
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            # cuDNN is disabled for the CTC loss computation.
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
|
|
|
|
|
@add_start_docstrings(
    """
    Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
    like SUPERB Keyword Spotting.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
            )
        self.data2vec_audio = Data2VecAudioModel(config)
        # One weight per hidden-states entry (embedding output + each encoder layer) when weighting layers.
        num_layers = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.

        Deprecated: emits a `FutureWarning` and delegates to `freeze_feature_encoder`.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.data2vec_audio.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.data2vec_audio.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A classification loss (Cross-Entropy) is always computed, including when
            `config.num_labels == 1`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden states.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.data2vec_audio(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # Mean over non-padded frames only: zero the padded positions, then divide by the valid-frame count.
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
@add_start_docstrings(
    """
    Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Audio frame classification does not support the use of Data2VecAudio adapters"
                " (config.add_adapter=True)"
            )
        self.data2vec_audio = Data2VecAudioModel(config)
        # One weight per hidden-states entry (embedding output + each encoder layer) when weighting layers.
        num_layers = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.num_labels = config.num_labels

        # Initialize weights and apply final processing (same entry point as the other heads in this file)
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.

        Deprecated: emits a `FutureWarning` and delegates to `freeze_feature_encoder`.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.data2vec_audio.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.data2vec_audio.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*):
            Frame-level targets holding one score per class for every output frame (for example one-hot vectors).
            They are reshaped to `(-1, config.num_labels)` and the arg-max class of each row is used as the target of
            a classification loss (Cross-Entropy), so `sequence_length` must match the number of frames in `logits`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden states.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.data2vec_audio(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        # Per-frame logits; no pooling over time.
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), dim=1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            # Prepend the loss like the other heads do; previously it was computed and then discarded.
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
|
class AMSoftmaxLoss(nn.Module):
    """
    Additive-margin softmax loss.

    Holds a learnable matrix of shape `(input_dim, num_labels)`. Its columns and the incoming feature rows are
    L2-normalised, so their product is a matrix of cosine similarities. `margin` is subtracted from the similarity of
    each sample's target class, everything is multiplied by `scale`, and the result is scored with cross-entropy.

    Args:
        input_dim: width of the features passed to `forward`.
        num_labels: number of target classes.
        scale: multiplier applied to the (margin-adjusted) cosine similarities.
        margin: amount subtracted from the target-class cosine similarity.
    """

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        """
        Args:
            hidden_states: 2-D tensor `(num_samples, input_dim)` (a matrix product with `self.weight` is taken).
            labels: class indices; flattened to 1-D before use.

        Returns:
            The cross-entropy loss of the scaled, margin-adjusted cosine logits.
        """
        targets = labels.flatten()
        unit_weight = nn.functional.normalize(self.weight, dim=0)
        unit_features = nn.functional.normalize(hidden_states, dim=1)
        cosine = unit_features.mm(unit_weight)

        # Only each row's target-class entry receives the margin penalty.
        is_target = nn.functional.one_hot(targets, self.num_labels).bool()
        adjusted = torch.where(is_target, cosine - self.margin, cosine)
        return self.loss(self.scale * adjusted, targets)
|
|
|
|
|
|
|
class TDNNLayer(nn.Module):
    """
    One time-delay (TDNN) layer: a dilated, stride-1 convolution over the time axis, implemented as `unfold`
    followed by `nn.Linear`, then a ReLU.

    Layer `layer_id` reads its sizes from `config.tdnn_dim`, `config.tdnn_kernel` and `config.tdnn_dilation`. Its
    input width is the previous layer's width, or its own width when `layer_id == 0`.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        source_id = layer_id - 1 if layer_id > 0 else layer_id
        self.in_conv_dim = config.tdnn_dim[source_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.kernel_size * self.in_conv_dim, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        # Expects a 3-D input whose last axis is the `in_conv_dim`-wide feature axis (assumption -- confirm with the
        # caller). Adding a channel axis lets `unfold` gather, for every output step, `kernel_size` frames spaced
        # `dilation` apart, each patch spanning the full feature width.
        patches = nn.functional.unfold(
            hidden_states.unsqueeze(1),
            kernel_size=(self.kernel_size, self.in_conv_dim),
            dilation=(self.dilation, 1),
            stride=(1, self.in_conv_dim),
        )
        # (batch, kernel_size * in_conv_dim, steps) -> (batch, steps, kernel_size * in_conv_dim)
        return self.activation(self.kernel(patches.transpose(1, 2)))
|
|
|
|
|
@add_start_docstrings(
    """
    Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.data2vec_audio = Data2VecAudioModel(config)
        # One weight per hidden-states entry (embedding output + each encoder layer) when weighting layers.
        num_layers = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # Statistics pooling concatenates mean and std, hence twice the last TDNN width.
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        # Initialize weights and apply final processing (same entry point as the other heads in this file)
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.

        Deprecated: emits a `FutureWarning` and delegates to `freeze_feature_encoder`.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.data2vec_audio.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.data2vec_audio.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers.

        Each `TDNNLayer` is an unpadded, stride-1 dilated convolution over time, so it shortens the sequence by
        `dilation * (kernel_size - 1)` frames. The kernel size and dilation are read from the instantiated layers so
        the lengths always agree with what `forward` actually produces.
        """

        def _conv_out_length(input_length, kernel_size, stride, dilation=1):
            # 1D convolutional layer output length formula (no padding) taken from
            # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - dilation * (kernel_size - 1) - 1) // stride + 1

        for tdnn_layer in self.tdnn:
            input_lengths = _conv_out_length(input_lengths, tdnn_layer.kernel_size, 1, tdnn_layer.dilation)

        return input_lengths

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, XVectorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class indices in `[0, ..., config.num_labels - 1]`, used to compute the additive-margin softmax loss
            (`AMSoftmaxLoss`) on the classifier output.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden states.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.data2vec_audio(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic pooling: mean and std over time, restricted to each sample's valid frames when a mask is given.
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|