jasspier committed
Commit
2ebf0bb
1 Parent(s): ca625d0

Create wav2vec2.py

Files changed (1)
  1. wav2vec2.py +1499 -0
wav2vec2.py ADDED
@@ -0,0 +1,1499 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from typing import List, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from fairseq import utils
16
+ from fairseq.data.data_utils import compute_mask_indices
17
+ from fairseq.dataclass import ChoiceEnum, FairseqDataclass
18
+ from fairseq.distributed import fsdp_wrap
19
+ from fairseq.models import BaseFairseqModel, register_model
20
+ from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel
21
+ from fairseq.modules import (
22
+ Fp32GroupNorm,
23
+ Fp32LayerNorm,
24
+ GradMultiply,
25
+ GumbelVectorQuantizer,
26
+ LayerNorm,
27
+ MultiheadAttention,
28
+ RelPositionalEncoding,
29
+ SamePad,
30
+ TransposeLast,
31
+ )
32
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
33
+ from fairseq.modules.conformer_layer import ConformerWav2Vec2EncoderLayer
34
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
35
+ from fairseq.utils import buffered_arange, index_put, is_xla_tensor
36
+
37
+ from fairseq.models.wav2vec.utils import pad_to_multiple
38
+
39
+ EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
40
+ MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
41
+ LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"])
42
+
43
+
44
+ @dataclass
45
+ class Wav2Vec2Config(FairseqDataclass):
46
+ extractor_mode: EXTRACTOR_MODE_CHOICES = field(
47
+ default="default",
48
+ metadata={
49
+ "help": "mode for feature extractor. default has a single group norm with d "
50
+ "groups in the first conv block, whereas layer_norm has layer norms in "
51
+ "every block (meant to use with normalize=True)"
52
+ },
53
+ )
54
+ encoder_layers: int = field(
55
+ default=12, metadata={"help": "num encoder layers in the transformer"}
56
+ )
57
+ encoder_embed_dim: int = field(
58
+ default=768, metadata={"help": "encoder embedding dimension"}
59
+ )
60
+ encoder_ffn_embed_dim: int = field(
61
+ default=3072, metadata={"help": "encoder embedding dimension for FFN"}
62
+ )
63
+ encoder_attention_heads: int = field(
64
+ default=12, metadata={"help": "num encoder attention heads"}
65
+ )
66
+ activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
67
+ default="gelu", metadata={"help": "activation function to use"}
68
+ )
69
+ layer_type: LAYER_TYPE_CHOICES = field(
70
+ default="transformer", metadata={"help": "layer type in encoder"}
71
+ )
72
+ # dropouts
73
+ dropout: float = field(
74
+ default=0.1, metadata={"help": "dropout probability for the transformer"}
75
+ )
76
+ attention_dropout: float = field(
77
+ default=0.1, metadata={"help": "dropout probability for attention weights"}
78
+ )
79
+ activation_dropout: float = field(
80
+ default=0.0, metadata={"help": "dropout probability after activation in FFN"}
81
+ )
82
+ encoder_layerdrop: float = field(
83
+ default=0.0, metadata={"help": "probability of dropping a tarnsformer layer"}
84
+ )
85
+ dropout_input: float = field(
86
+ default=0.0,
87
+ metadata={"help": "dropout to apply to the input (after feat extr)"},
88
+ )
89
+ dropout_features: float = field(
90
+ default=0.0,
91
+ metadata={"help": "dropout to apply to the features (after feat extr)"},
92
+ )
93
+
94
+ final_dim: int = field(
95
+ default=0,
96
+ metadata={
97
+ "help": "project final representations and targets to this many dimensions."
98
+ "set to encoder_embed_dim is <= 0"
99
+ },
100
+ )
101
+ layer_norm_first: bool = field(
102
+ default=False, metadata={"help": "apply layernorm first in the transformer"}
103
+ )
104
+ input_feature_ndim: int = field(
105
+ default=40,
106
+ metadata={"help": "number of mfcc/fbank feature dimensions, e.g. 40"}
107
+ )
108
+ conv_feature_layers: str = field(
109
+ default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
110
+ metadata={
111
+ "help": "string describing convolutional feature extraction layers in form of a python list that contains "
112
+ "[(dim, kernel_size, stride), ...]"
113
+ },
114
+ )
115
+ conv_bias: bool = field(
116
+ default=False, metadata={"help": "include bias in conv encoder"}
117
+ )
118
+ logit_temp: float = field(
119
+ default=0.1, metadata={"help": "temperature to divide logits by"}
120
+ )
121
+ quantize_targets: bool = field(
122
+ default=False, metadata={"help": "use quantized targets"}
123
+ )
124
+ quantize_input: bool = field(
125
+ default=False, metadata={"help": "use quantized inputs"}
126
+ )
127
+ same_quantizer: bool = field(
128
+ default=False, metadata={"help": "use same quantizer for inputs and targets"}
129
+ )
130
+ target_glu: bool = field(
131
+ default=False, metadata={"help": "adds projection + glu to targets"}
132
+ )
133
+ feature_grad_mult: float = field(
134
+ default=1.0, metadata={"help": "multiply feature extractor var grads by this"}
135
+ )
136
+ quantizer_depth: int = field(
137
+ default=1,
138
+ metadata={"help": "number of quantizer layers"},
139
+ )
140
+ quantizer_factor: int = field(
141
+ default=3,
142
+ metadata={
143
+ "help": "dimensionality increase for inner quantizer layers (if depth > 1)"
144
+ },
145
+ )
146
+ latent_vars: int = field(
147
+ default=320,
148
+ metadata={"help": "number of latent variables V in each group of the codebook"},
149
+ )
150
+ latent_groups: int = field(
151
+ default=2,
152
+ metadata={"help": "number of groups G of latent variables in the codebook"},
153
+ )
154
+ latent_dim: int = field(
155
+ default=0,
156
+ metadata={
157
+ "help": "if > 0, uses this dimensionality for latent variables. "
158
+ "otherwise uses final_dim / latent_groups"
159
+ },
160
+ )
161
+
162
+ # masking
163
+ mask_length: int = field(default=10, metadata={"help": "mask length"})
164
+ mask_prob: float = field(
165
+ default=0.65, metadata={"help": "probability of replacing a token with mask"}
166
+ )
167
+ mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
168
+ default="static", metadata={"help": "how to choose mask length"}
169
+ )
170
+ mask_other: float = field(
171
+ default=0,
172
+ metadata={
173
+ "help": "secondary mask argument (used for more complex distributions), "
174
+ "see help in compute_mask_indices"
175
+ },
176
+ )
177
+ no_mask_overlap: bool = field(
178
+ default=False, metadata={"help": "whether to allow masks to overlap"}
179
+ )
180
+ mask_min_space: int = field(
181
+ default=1,
182
+ metadata={"help": "min space between spans (if no overlap is enabled)"},
183
+ )
184
+ require_same_masks: bool = field(
185
+ default=True,
186
+ metadata={
187
+ "help": "whether to number of masked timesteps must be the same across all "
188
+ "examples in a batch"
189
+ },
190
+ )
191
+ mask_dropout: float = field(
192
+ default=0.0,
193
+ metadata={"help": "percent of masks to unmask for each sample"},
194
+ )
195
+
196
+ # channel masking
197
+ mask_channel_length: int = field(
198
+ default=10, metadata={"help": "length of the mask for features (channels)"}
199
+ )
200
+ mask_channel_prob: float = field(
201
+ default=0.0, metadata={"help": "probability of replacing a feature with 0"}
202
+ )
203
+ mask_channel_before: bool = False
204
+ mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
205
+ default="static",
206
+ metadata={"help": "how to choose mask length for channel masking"},
207
+ )
208
+ mask_channel_other: float = field(
209
+ default=0,
210
+ metadata={
211
+ "help": "secondary mask argument (used for more complex distributions), "
212
+ "see help in compute_mask_indicesh"
213
+ },
214
+ )
215
+ no_mask_channel_overlap: bool = field(
216
+ default=False, metadata={"help": "whether to allow channel masks to overlap"}
217
+ )
218
+ mask_channel_min_space: int = field(
219
+ default=1,
220
+ metadata={"help": "min space between spans (if no overlap is enabled)"},
221
+ )
222
+
223
+ # negative selection
224
+ num_negatives: int = field(
225
+ default=100,
226
+ metadata={"help": "number of negative examples from the same sample"},
227
+ )
228
+ negatives_from_everywhere: bool = field(
229
+ default=False,
230
+ metadata={"help": "sample negatives from everywhere, not just masked states"},
231
+ )
232
+ cross_sample_negatives: int = field(
233
+ default=0, metadata={"help": "number of negative examples from the any sample"}
234
+ )
235
+ codebook_negatives: int = field(
236
+ default=0, metadata={"help": "number of negative examples codebook"}
237
+ )
238
+
239
+ # positional embeddings
240
+ conv_pos: int = field(
241
+ default=128,
242
+ metadata={"help": "number of filters for convolutional positional embeddings"},
243
+ )
244
+ conv_pos_groups: int = field(
245
+ default=16,
246
+ metadata={"help": "number of groups for convolutional positional embedding"},
247
+ )
248
+ pos_conv_depth: int = field(
249
+ default=1,
250
+ metadata={"help": "depth of positional encoder network"},
251
+ )
252
+
253
+ latent_temp: Tuple[float, float, float] = field(
254
+ default=(2, 0.5, 0.999995),
255
+ metadata={
256
+ "help": "temperature for latent variable sampling. "
257
+ "can be tuple of 3 values (start, end, decay)"
258
+ },
259
+ )
260
+ max_positions: int = field(default=100000, metadata={"help": "Max positions"})
261
+ checkpoint_activations: bool = field(
262
+ default=False,
263
+ metadata={"help": "recompute activations and save memory for extra compute"},
264
+ )
265
+
266
+ # FP16 optimization
267
+ required_seq_len_multiple: int = field(
268
+ default=2,
269
+ metadata={
270
+ "help": "pad the input to encoder such that the sequence length is divisible by multiple"
271
+ },
272
+ )
273
+ crop_seq_to_multiple: int = field(
274
+ default=1,
275
+ metadata={
276
+ "help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple"
277
+ },
278
+ )
279
+
280
+ # Conformer
281
+ depthwise_conv_kernel_size: int = field(
282
+ default=31,
283
+ metadata={
284
+ "help": "depthwise-conv-kernel-size for convolution in conformer layer"
285
+ },
286
+ )
287
+ attn_type: str = field(
288
+ default="",
289
+ metadata={"help": "if espnet use ESPNET MHA"},
290
+ )
291
+ pos_enc_type: str = field(
292
+ default="abs",
293
+ metadata={"help": "Positional encoding type to use in conformer"},
294
+ )
295
+ fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
296
+
297
+ # Adapter num
298
+ adp_num: int = field(
299
+ default=-1
300
+ )
301
+ adp_dim: int = field(
302
+ default=64
303
+ )
304
+ adp_act_fn: str = field(
305
+ default="relu"
306
+ )
307
+ adp_trf_idx: str = field(
308
+ default="all",
309
+ )
310
+
311
+
312
+ @register_model("wav2vec2", dataclass=Wav2Vec2Config)
313
+ class Wav2Vec2Model(BaseFairseqModel):
314
+ def __init__(self, cfg: Wav2Vec2Config):
315
+ super().__init__()
316
+ self.cfg = cfg
317
+
318
+ feature_enc_layers = eval(cfg.conv_feature_layers)
319
+ self.embed = feature_enc_layers[-1][0]
320
+
321
+ self.feature_extractor = ConvFeatureExtractionModel(
322
+ conv_layers=feature_enc_layers,
323
+ dropout=0.0,
324
+ mode=cfg.extractor_mode,
325
+ conv_bias=cfg.conv_bias,
326
+ input_feature_ndim=cfg.input_feature_ndim
327
+ )
328
+
329
+ self.post_extract_proj = (
330
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
331
+ if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
332
+ else None
333
+ )
334
+
335
+ self.crop_seq_to_multiple = cfg.crop_seq_to_multiple
336
+
337
+ self.mask_prob = cfg.mask_prob
338
+ self.mask_selection = cfg.mask_selection
339
+ self.mask_other = cfg.mask_other
340
+ self.mask_length = cfg.mask_length
341
+ self.no_mask_overlap = cfg.no_mask_overlap
342
+ self.mask_min_space = cfg.mask_min_space
343
+
344
+ self.mask_channel_prob = cfg.mask_channel_prob
345
+ self.mask_channel_before = cfg.mask_channel_before
346
+ self.mask_channel_selection = cfg.mask_channel_selection
347
+ self.mask_channel_other = cfg.mask_channel_other
348
+ self.mask_channel_length = cfg.mask_channel_length
349
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
350
+ self.mask_channel_min_space = cfg.mask_channel_min_space
351
+
352
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
353
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
354
+
355
+ self.feature_grad_mult = cfg.feature_grad_mult
356
+
357
+ self.quantizer = None
358
+ self.input_quantizer = None
359
+
360
+ self.n_negatives = cfg.num_negatives
361
+ self.cross_sample_negatives = cfg.cross_sample_negatives
362
+ self.codebook_negatives = cfg.codebook_negatives
363
+ self.negatives_from_everywhere = cfg.negatives_from_everywhere
364
+
365
+ self.logit_temp = cfg.logit_temp
366
+
367
+ final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
368
+
369
+ if cfg.quantize_targets:
370
+ vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
371
+ self.quantizer = GumbelVectorQuantizer(
372
+ dim=self.embed,
373
+ num_vars=cfg.latent_vars,
374
+ temp=cfg.latent_temp,
375
+ groups=cfg.latent_groups,
376
+ combine_groups=False,
377
+ vq_dim=vq_dim,
378
+ time_first=True,
379
+ weight_proj_depth=cfg.quantizer_depth,
380
+ weight_proj_factor=cfg.quantizer_factor,
381
+ )
382
+ self.project_q = nn.Linear(vq_dim, final_dim)
383
+ else:
384
+ self.project_q = nn.Linear(self.embed, final_dim)
385
+
386
+ if cfg.quantize_input:
387
+ if cfg.same_quantizer and self.quantizer is not None:
388
+ vq_dim = final_dim
389
+ self.input_quantizer = self.quantizer
390
+ else:
391
+ vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
392
+ self.input_quantizer = GumbelVectorQuantizer(
393
+ dim=self.embed,
394
+ num_vars=cfg.latent_vars,
395
+ temp=cfg.latent_temp,
396
+ groups=cfg.latent_groups,
397
+ combine_groups=False,
398
+ vq_dim=vq_dim,
399
+ time_first=True,
400
+ weight_proj_depth=cfg.quantizer_depth,
401
+ weight_proj_factor=cfg.quantizer_factor,
402
+ )
403
+ self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)
404
+
405
+ self.mask_emb = nn.Parameter(
406
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
407
+ )
408
+ encoder_cls = TransformerEncoder
409
+ if cfg.layer_type == "conformer" and cfg.pos_enc_type in ["rel_pos", "rope"]:
410
+ encoder_cls = ConformerEncoder
411
+
412
+ self.encoder = encoder_cls(cfg)
413
+ self.layer_norm = LayerNorm(self.embed)
414
+
415
+ self.target_glu = None
416
+ if cfg.target_glu:
417
+ self.target_glu = nn.Sequential(
418
+ nn.Linear(final_dim, final_dim * 2), nn.GLU()
419
+ )
420
+
421
+ self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
422
+
423
+ def upgrade_state_dict_named(self, state_dict, name):
424
+ super().upgrade_state_dict_named(state_dict, name)
425
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
426
+ return state_dict
427
+
428
+ @classmethod
429
+ def build_model(cls, cfg: Wav2Vec2Config, task=None):
430
+ """Build a new model instance."""
431
+
432
+ return cls(cfg)
433
+
434
+ def apply_mask(
435
+ self,
436
+ x,
437
+ padding_mask,
438
+ mask_indices=None,
439
+ mask_channel_indices=None,
440
+ ):
441
+ B, T, C = x.shape
442
+
443
+ if self.mask_channel_prob > 0 and self.mask_channel_before:
444
+ mask_channel_indices = compute_mask_indices(
445
+ (B, C),
446
+ None,
447
+ self.mask_channel_prob,
448
+ self.mask_channel_length,
449
+ self.mask_channel_selection,
450
+ self.mask_channel_other,
451
+ no_overlap=self.no_mask_channel_overlap,
452
+ min_space=self.mask_channel_min_space,
453
+ )
454
+ mask_channel_indices = (
455
+ torch.from_numpy(mask_channel_indices)
456
+ .to(x.device)
457
+ .unsqueeze(1)
458
+ .expand(-1, T, -1)
459
+ )
460
+ x[mask_channel_indices] = 0
461
+
462
+ if self.mask_prob > 0:
463
+ if mask_indices is None:
464
+ mask_indices = compute_mask_indices(
465
+ (B, T),
466
+ padding_mask,
467
+ self.mask_prob,
468
+ self.mask_length,
469
+ self.mask_selection,
470
+ self.mask_other,
471
+ min_masks=2,
472
+ no_overlap=self.no_mask_overlap,
473
+ min_space=self.mask_min_space,
474
+ require_same_masks=self.cfg.require_same_masks,
475
+ mask_dropout=self.cfg.mask_dropout,
476
+ )
477
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
478
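+ # replace every masked timestep with the learned mask embedding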
+ x = index_put(x, mask_indices, self.mask_emb)
479
+ else:
480
+ mask_indices = None
481
+
482
+ if self.mask_channel_prob > 0 and not self.mask_channel_before:
483
+ if mask_channel_indices is None:
484
+ mask_channel_indices = compute_mask_indices(
485
+ (B, C),
486
+ None,
487
+ self.mask_channel_prob,
488
+ self.mask_channel_length,
489
+ self.mask_channel_selection,
490
+ self.mask_channel_other,
491
+ no_overlap=self.no_mask_channel_overlap,
492
+ min_space=self.mask_channel_min_space,
493
+ )
494
+ mask_channel_indices = (
495
+ torch.from_numpy(mask_channel_indices)
496
+ .to(x.device)
497
+ .unsqueeze(1)
498
+ .expand(-1, T, -1)
499
+ )
500
+ x = index_put(x, mask_channel_indices, 0)
501
+
502
+ return x, mask_indices
503
+
504
+ def sample_negatives(self, y, num, padding_count=None):
505
+
506
+ if self.n_negatives == 0 and self.cross_sample_negatives == 0:
507
+ return y.new(0)
508
+
509
+ bsz, tsz, fsz = y.shape
510
+ y = y.view(-1, fsz) # BTC => (BxT)C
511
+
512
+ # FIXME: what happens if padding_count is specified?
513
+ cross_high = tsz * bsz
514
+ high = tsz - (padding_count or 0)
515
+ with torch.no_grad():
516
+ assert high > 1, f"{bsz,tsz,fsz}"
517
+
518
+ if self.n_negatives > 0:
519
+ tszs = (
520
+ buffered_arange(num)
521
+ .unsqueeze(-1)
522
+ .expand(-1, self.n_negatives)
523
+ .flatten()
524
+ )
525
+
526
+ neg_idxs = torch.randint(
527
+ low=0, high=high - 1, size=(bsz, self.n_negatives * num)
528
+ )
529
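+ # indices were drawn from [0, high-2]; shifting those at/after the positive's own timestep up by one ensures a negative is never the positive itself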
+ neg_idxs[neg_idxs >= tszs] += 1
530
+
531
+ if self.cross_sample_negatives > 0:
532
+ tszs = (
533
+ buffered_arange(num)
534
+ .unsqueeze(-1)
535
+ .expand(-1, self.cross_sample_negatives)
536
+ .flatten()
537
+ )
538
+
539
+ cross_neg_idxs = torch.randint(
540
+ low=0,
541
+ high=cross_high - 1,
542
+ size=(bsz, self.cross_sample_negatives * num),
543
+ )
544
+ cross_neg_idxs[cross_neg_idxs >= tszs] += 1
545
+
546
+ if self.n_negatives > 0:
547
+ neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)
548
+ else:
549
+ neg_idxs = cross_neg_idxs
550
+
551
+ if self.cross_sample_negatives > 0 and self.n_negatives > 0:
552
+ neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)
553
+
554
+ negs = y[neg_idxs.view(-1)]
555
+ negs = negs.view(
556
+ bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
557
+ ).permute(
558
+ 2, 0, 1, 3
559
+ ) # to NxBxTxC
560
+ return negs, neg_idxs
561
+
562
+ def compute_preds(self, x, y, negatives):
563
+
564
+ neg_is_pos = (y == negatives).all(-1)
565
+ y = y.unsqueeze(0)
566
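+ # prepend the positive target so that class 0 of the resulting logits is always the correct one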
+ targets = torch.cat([y, negatives], dim=0)
567
+
568
+ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
569
+ logits = logits / self.logit_temp
570
+ logits = logits.type_as(x)
571
+
572
+ if is_xla_tensor(logits) or neg_is_pos.any():
573
+ if not hasattr(self, "_inftensor"):
574
+ fillval = -float(2**30)
575
+ self._inftensor = (
576
+ torch.tensor(fillval).to(x.device)
577
+ if is_xla_tensor(logits)
578
+ else float("-inf")
579
+ )
580
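+ # negatives that happen to be identical to the positive are filled with -inf (a large negative constant on XLA) so they cannot win the contrastive task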
+ logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)
581
+
582
+ return logits
583
+
584
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
585
+ """
586
+ Computes the output length of the convolutional layers
587
+ """
588
+
589
+ def _conv_out_length(input_length, kernel_size, stride):
590
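+ # standard 1-D conv output length with no padding and dilation 1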
+ return torch.floor((input_length - kernel_size) / stride + 1)
591
+
592
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
593
+
594
+ for i in range(len(conv_cfg_list)):
595
+ input_lengths = _conv_out_length(
596
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
597
+ )
598
+
599
+ return input_lengths.to(torch.long)
600
+
601
+ def forward(
602
+ self,
603
+ source,
604
+ padding_mask=None,
605
+ mask=True,
606
+ features_only=False,
607
+ layer=None,
608
+ mask_indices=None,
609
+ mask_channel_indices=None,
610
+ padding_count=None,
611
+ corpus_key=None,
612
+ ):
613
+
614
+ if self.feature_grad_mult > 0:
615
+ features = self.feature_extractor(source)
616
+ if self.feature_grad_mult != 1.0:
617
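+ # scale gradients flowing back into the feature extractor; forward values are unchanged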
+ features = GradMultiply.apply(features, self.feature_grad_mult)
618
+ else:
619
+ with torch.no_grad():
620
+ features = self.feature_extractor(source)
621
+
622
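+ # L2 penalty on extractor activations, exposed to the criterion via get_extra_losses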
+ features_pen = features.float().pow(2).mean()
623
+
624
+ features = features.transpose(1, 2)
625
+ features = self.layer_norm(features)
626
+ unmasked_features = features.clone()
627
+
628
+ if padding_mask is not None and padding_mask.any():
629
+ input_lengths = (1 - padding_mask.long()).sum(-1)
630
+ # apply conv formula to get real output_lengths
631
+ output_lengths = self._get_feat_extract_output_lengths(input_lengths)
632
+
633
+ padding_mask = torch.zeros(
634
+ features.shape[:2], dtype=features.dtype, device=features.device
635
+ )
636
+
637
+ # these two operations make sure that all values
638
+ # before the output lengths indices are attended to
639
+ padding_mask[
640
+ (
641
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
642
+ output_lengths - 1,
643
+ )
644
+ ] = 1
645
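+ # the reversed cumulative sum extends the single 1 at (output_length - 1) back to the start; inverting it marks frames at and after output_length as padding (True)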
+ padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
646
+ else:
647
+ padding_mask = None
648
+
649
+ time_steps_to_drop = features.size(1) % self.crop_seq_to_multiple
650
+ if time_steps_to_drop != 0:
651
+ features = features[:, :-time_steps_to_drop]
652
+ unmasked_features = unmasked_features[:, :-time_steps_to_drop]
653
+ if padding_mask is not None:
654
+ padding_mask = padding_mask[:, :-time_steps_to_drop]
655
+
656
+ if self.post_extract_proj is not None:
657
+ features = self.post_extract_proj(features)
658
+
659
+ features = self.dropout_input(features)
660
+ unmasked_features = self.dropout_features(unmasked_features)
661
+
662
+ num_vars = None
663
+ code_ppl = None
664
+ prob_ppl = None
665
+ curr_temp = None
666
+
667
+ if self.input_quantizer:
668
+ q = self.input_quantizer(features, produce_targets=False)
669
+ features = q["x"]
670
+ num_vars = q["num_vars"]
671
+ code_ppl = q["code_perplexity"]
672
+ prob_ppl = q["prob_perplexity"]
673
+ curr_temp = q["temp"]
674
+ features = self.project_inp(features)
675
+
676
+ if mask:
677
+ x, mask_indices = self.apply_mask(
678
+ features,
679
+ padding_mask,
680
+ mask_indices=mask_indices,
681
+ mask_channel_indices=mask_channel_indices,
682
+ )
683
+ if not is_xla_tensor(x) and mask_indices is not None:
684
+ # tpu-comment: reducing the size in a dynamic way causes
685
+ # too many recompilations on xla.
686
+ y = unmasked_features[mask_indices].view(
687
+ unmasked_features.size(0), -1, unmasked_features.size(-1)
688
+ )
689
+ else:
690
+ y = unmasked_features
691
+ else:
692
+ x = features
693
+ y = unmasked_features
694
+ mask_indices = None
695
+
696
+ x, layer_results = self.encoder(
697
+ x, padding_mask=padding_mask, layer=layer, corpus_key=corpus_key
698
+ )
699
+
700
+ if features_only:
701
+ return {
702
+ "x": x,
703
+ "padding_mask": padding_mask,
704
+ "features": unmasked_features,
705
+ "layer_results": layer_results,
706
+ }
707
+
708
+ if self.quantizer:
709
+ if self.negatives_from_everywhere:
710
+ q = self.quantizer(unmasked_features, produce_targets=False)
711
+ y = q["x"]
712
+ num_vars = q["num_vars"]
713
+ code_ppl = q["code_perplexity"]
714
+ prob_ppl = q["prob_perplexity"]
715
+ curr_temp = q["temp"]
716
+ y = self.project_q(y)
717
+
718
+ negs, _ = self.sample_negatives(
719
+ y,
720
+ mask_indices[0].sum(),
721
+ padding_count=padding_count,
722
+ )
723
+ y = y[mask_indices].view(y.size(0), -1, y.size(-1))
724
+
725
+ else:
726
+ q = self.quantizer(y, produce_targets=False)
727
+ y = q["x"]
728
+ num_vars = q["num_vars"]
729
+ code_ppl = q["code_perplexity"]
730
+ prob_ppl = q["prob_perplexity"]
731
+ curr_temp = q["temp"]
732
+
733
+ y = self.project_q(y)
734
+
735
+ negs, _ = self.sample_negatives(
736
+ y,
737
+ y.size(1),
738
+ padding_count=padding_count,
739
+ )
740
+
741
+ if self.codebook_negatives > 0:
742
+ cb_negs = self.quantizer.sample_from_codebook(
743
+ y.size(0) * y.size(1), self.codebook_negatives
744
+ )
745
+ cb_negs = cb_negs.view(
746
+ self.codebook_negatives, y.size(0), y.size(1), -1
747
+ ) # order doesn't matter
748
+ cb_negs = self.project_q(cb_negs)
749
+ negs = torch.cat([negs, cb_negs], dim=0)
750
+ else:
751
+ y = self.project_q(y)
752
+
753
+ if self.negatives_from_everywhere:
754
+ negs, _ = self.sample_negatives(
755
+ unmasked_features,
756
+ y.size(1),
757
+ padding_count=padding_count,
758
+ )
759
+ negs = self.project_q(negs)
760
+ else:
761
+ negs, _ = self.sample_negatives(
762
+ y,
763
+ y.size(1),
764
+ padding_count=padding_count,
765
+ )
766
+
767
+ if not is_xla_tensor(x):
768
+ # tpu-comment: reducing the size in a dynamic way causes
769
+ # too many recompilations on xla.
770
+ x = x[mask_indices].view(x.size(0), -1, x.size(-1))
771
+
772
+ if self.target_glu:
773
+ y = self.target_glu(y)
774
+ negs = self.target_glu(negs)
775
+
776
+ x = self.final_proj(x)
777
+ x = self.compute_preds(x, y, negs)
778
+
779
+ result = {
780
+ "x": x,
781
+ "padding_mask": padding_mask,
782
+ "features_pen": features_pen,
783
+ }
784
+
785
+ if prob_ppl is not None:
786
+ result["prob_perplexity"] = prob_ppl
787
+ result["code_perplexity"] = code_ppl
788
+ result["num_vars"] = num_vars
789
+ result["temp"] = curr_temp
790
+
791
+ return result
792
+
793
+ def quantize(self, x):
794
+ assert self.quantizer is not None
795
+ x = self.feature_extractor(x)
796
+ x = x.transpose(1, 2)
797
+ x = self.layer_norm(x)
798
+ return self.quantizer.forward_idx(x)
799
+
800
+ def extract_features(
801
+ self, source, padding_mask, mask=False, layer=None, corpus_key=None
802
+ ):
803
+ res = self.forward(
804
+ source,
805
+ padding_mask,
806
+ mask=mask,
807
+ features_only=True,
808
+ layer=layer,
809
+ corpus_key=corpus_key,
810
+ )
811
+ return res
812
+
813
+ def get_logits(self, net_output):
814
+ logits = net_output["x"]
815
+ logits = logits.transpose(0, 2)
816
+ logits = logits.reshape(-1, logits.size(-1))
817
+ return logits
818
+
819
+ def get_targets(self, sample, net_output, expand_steps=True):
820
+ x = net_output["x"]
821
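+ # compute_preds places the positive at index 0, so the target class is always 0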
+ return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)
822
+
823
+ def get_extra_losses(self, net_output):
824
+ pen = []
825
+
826
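+ # codebook diversity penalty: (num_vars - prob_perplexity) / num_vars encourages using all codebook entries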
+ if "prob_perplexity" in net_output:
827
+ pen.append(
828
+ (net_output["num_vars"] - net_output["prob_perplexity"])
829
+ / net_output["num_vars"]
830
+ )
831
+
832
+ if "features_pen" in net_output:
833
+ pen.append(net_output["features_pen"])
834
+
835
+ return pen
836
+
837
+ def remove_pretraining_modules(self, last_layer=None):
838
+ self.quantizer = None
839
+ self.project_q = None
840
+ self.target_glu = None
841
+ self.final_proj = None
842
+
843
+ if last_layer is not None:
844
+ self.encoder.layers = nn.ModuleList(
845
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
846
+ )
847
+
848
+
849
+ class ConvFeatureExtractionModel(nn.Module):
850
+ def __init__(
851
+ self,
852
+ conv_layers: List[Tuple[int, int, int]],
853
+ dropout: float = 0.0,
854
+ mode: str = "default",
855
+ conv_bias: bool = False,
856
+ input_feature_ndim: int = 40
857
+ ):
858
+ super().__init__()
859
+
860
+ assert mode in {"default", "layer_norm"}
861
+
862
+ def block(
863
+ n_in,
864
+ n_out,
865
+ k,
866
+ stride,
867
+ is_layer_norm=False,
868
+ is_group_norm=False,
869
+ conv_bias=False,
870
+ ):
871
+ def make_conv():
872
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
873
+ nn.init.kaiming_normal_(conv.weight)
874
+ return conv
875
+
876
+ assert (
877
+ is_layer_norm and is_group_norm
878
+ ) == False, "layer norm and group norm are exclusive"
879
+
880
+ if is_layer_norm:
881
+ return nn.Sequential(
882
+ make_conv(),
883
+ nn.Dropout(p=dropout),
884
+ nn.Sequential(
885
+ TransposeLast(),
886
+ Fp32LayerNorm(dim, elementwise_affine=True),
887
+ TransposeLast(),
888
+ ),
889
+ nn.GELU(),
890
+ )
891
+ elif is_group_norm:
892
+ return nn.Sequential(
893
+ make_conv(),
894
+ nn.Dropout(p=dropout),
895
+ Fp32GroupNorm(dim, dim, affine=True),
896
+ nn.GELU(),
897
+ )
898
+ else:
899
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
900
+
901
+ in_d = input_feature_ndim
902
+ self.conv_layers = nn.ModuleList()
903
+ for i, cl in enumerate(conv_layers):
904
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
905
+ (dim, k, stride) = cl
906
+
907
+ self.conv_layers.append(
908
+ block(
909
+ in_d,
910
+ dim,
911
+ k,
912
+ stride,
913
+ is_layer_norm=mode == "layer_norm",
914
+ is_group_norm=mode == "default" and i == 0,
915
+ conv_bias=conv_bias,
916
+ )
917
+ )
918
+ in_d = dim
919
+
920
+ def forward(self, x):
921
+
922
+ # BxTxC -> BxCxT
923
+ #x = x.unsqueeze(1)
924
+ x = x.permute([0,2,1])
925
+
926
+ for conv in self.conv_layers:
927
+ x = conv(x)
928
+
929
+ return x
930
+
931
+
932
+ def make_conv_pos(e, k, g, is_batch_norm=False):
933
+ pos_conv = nn.Conv1d(
934
+ e,
935
+ e,
936
+ kernel_size=k,
937
+ padding=k // 2,
938
+ groups=g,
939
+ )
940
+ dropout = 0
941
+ std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
942
+ nn.init.normal_(pos_conv.weight, mean=0, std=std)
943
+ nn.init.constant_(pos_conv.bias, 0)
944
+
945
+ if not is_batch_norm:
946
+ pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
947
+ pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
948
+ else:
949
+ batch_norm = nn.BatchNorm1d(e)
950
+ pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU())
951
+
952
+ return pos_conv
953
+
954
+
955
+ class TransformerEncoder(nn.Module):
956
+ def build_encoder_layer(self, args: Wav2Vec2Config, **kwargs):
957
+ if args.layer_type == "transformer":
958
+ layer = TransformerSentenceEncoderLayer(
959
+ embedding_dim=self.embedding_dim,
960
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
961
+ num_attention_heads=args.encoder_attention_heads,
962
+ dropout=self.dropout,
963
+ attention_dropout=args.attention_dropout,
964
+ activation_dropout=args.activation_dropout,
965
+ activation_fn=args.activation_fn,
966
+ layer_norm_first=args.layer_norm_first,
967
+ )
968
+ elif args.layer_type == "conformer":
969
+ layer = ConformerWav2Vec2EncoderLayer(
970
+ embed_dim=self.embedding_dim,
971
+ ffn_embed_dim=args.encoder_ffn_embed_dim,
972
+ attention_heads=args.encoder_attention_heads,
973
+ dropout=args.dropout,
974
+ depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
975
+ activation_fn="swish",
976
+ attn_type=args.attn_type,
977
+ use_fp16=args.fp16,
978
+ pos_enc_type="abs",
979
+ )
980
+ elif args.layer_type == "trf_adp":
981
+ use_adp = False
982
+ if args.adp_trf_idx == "all":
983
+ use_adp = True
984
+ else:
985
+ adp_trf_idx = list(range(*[int(g) for g in args.adp_trf_idx.split(":")]))
986
+ if kwargs.get("layer_idx", None) in adp_trf_idx:
987
+ use_adp = True
988
+ if use_adp:
989
+ layer = TransformerSentenceEncoderWithAdapterLayer(
990
+ embedding_dim=self.embedding_dim,
991
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
992
+ num_attention_heads=args.encoder_attention_heads,
993
+ dropout=self.dropout,
994
+ attention_dropout=args.attention_dropout,
995
+ activation_dropout=args.activation_dropout,
996
+ activation_fn=args.activation_fn,
997
+ layer_norm_first=args.layer_norm_first,
998
+ adapter_num=args.adp_num,
999
+ adapter_dim=args.adp_dim,
1000
+ adapter_act_fn=args.adp_act_fn,
1001
+ )
1002
+ else:
1003
+ layer = TransformerSentenceEncoderLayer(
1004
+ embedding_dim=self.embedding_dim,
1005
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
1006
+ num_attention_heads=args.encoder_attention_heads,
1007
+ dropout=self.dropout,
1008
+ attention_dropout=args.attention_dropout,
1009
+ activation_dropout=args.activation_dropout,
1010
+ activation_fn=args.activation_fn,
1011
+ layer_norm_first=args.layer_norm_first,
1012
+ )
1013
+
1014
+ layer = fsdp_wrap(layer)
1015
+ if args.checkpoint_activations:
1016
+ layer = checkpoint_wrapper(layer)
1017
+ return layer
1018
+
1019
+ def __init__(self, args: Wav2Vec2Config):
1020
+ super().__init__()
1021
+
1022
+ self.args = args  # stored so max_positions() can read args.max_positions
+ self.dropout = args.dropout
1023
+ self.embedding_dim = args.encoder_embed_dim
1024
+ self.required_seq_len_multiple = args.required_seq_len_multiple
1025
+
1026
+ pos_conv_depth = getattr(args, "pos_conv_depth", 1)
1027
+ if pos_conv_depth > 1:
1028
+ num_layers = args.pos_conv_depth
1029
+ k = max(3, args.conv_pos // num_layers)
1030
+
1031
+ def make_conv_block(e, k, g, l):
1032
+ return nn.Sequential(
1033
+ *[
1034
+ nn.Sequential(
1035
+ nn.Conv1d(
1036
+ e,
1037
+ e,
1038
+ kernel_size=k,
1039
+ padding=k // 2,
1040
+ groups=g,
1041
+ ),
1042
+ SamePad(k),
1043
+ TransposeLast(),
1044
+ LayerNorm(e, elementwise_affine=False),
1045
+ TransposeLast(),
1046
+ nn.GELU(),
1047
+ )
1048
+ for _ in range(l)
1049
+ ]
1050
+ )
1051
+
1052
+ self.pos_conv = make_conv_block(
1053
+ self.embedding_dim, k, args.conv_pos_groups, num_layers
1054
+ )
1055
+
1056
+ else:
1057
+ self.pos_conv = make_conv_pos(
1058
+ self.embedding_dim,
1059
+ args.conv_pos,
1060
+ args.conv_pos_groups,
1061
+ is_batch_norm=args.conv_pos_batch_norm
1062
+ if hasattr(args, "conv_pos_batch_norm")
1063
+ else False,
1064
+ )
1065
+
1066
+ self.layers = nn.ModuleList(
1067
+ [self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)]
1068
+ )
1069
+ self.layer_norm_first = args.layer_norm_first
1070
+ self.layer_norm = LayerNorm(self.embedding_dim)
1071
+ self.layerdrop = args.encoder_layerdrop
1072
+
1073
+ self.apply(init_bert_params)
1074
+
1075
+ def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
1076
+ x, layer_results = self.extract_features(
1077
+ x, padding_mask, layer, corpus_key=corpus_key
1078
+ )
1079
+
1080
+ if self.layer_norm_first and layer is None:
1081
+ x = self.layer_norm(x)
1082
+
1083
+ return x, layer_results
1084
+
1085
+ def extract_features(
1086
+ self,
1087
+ x,
1088
+ padding_mask=None,
1089
+ tgt_layer=None,
1090
+ min_layer=0,
1091
+ corpus_key=None,
1092
+ ):
1093
+
1094
+ if padding_mask is not None:
1095
+ x = index_put(x, padding_mask, 0)
1096
+
1097
+ x_conv = self.pos_conv(x.transpose(1, 2))
1098
+ x_conv = x_conv.transpose(1, 2)
1099
+ x = x + x_conv
1100
+
1101
+ if not self.layer_norm_first:
1102
+ x = self.layer_norm(x)
1103
+
1104
+ # pad to the sequence length dimension
1105
+ x, pad_length = pad_to_multiple(
1106
+ x, self.required_seq_len_multiple, dim=-2, value=0
1107
+ )
1108
+ if pad_length > 0 and padding_mask is None:
1109
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
1110
+ padding_mask[:, -pad_length:] = True
1111
+ else:
1112
+ padding_mask, _ = pad_to_multiple(
1113
+ padding_mask, self.required_seq_len_multiple, dim=-1, value=True
1114
+ )
1115
+ x = F.dropout(x, p=self.dropout, training=self.training)
1116
+
1117
+ # B x T x C -> T x B x C
1118
+ x = x.transpose(0, 1)
1119
+
1120
+ layer_results = []
1121
+ r = None
1122
+
1123
+ for i, layer in enumerate(self.layers):
1124
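+ # LayerDrop: during training, skip this layer entirely with probability self.layerdrop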
+ dropout_probability = np.random.random() if self.layerdrop > 0 else 1
1125
+ if not self.training or (dropout_probability > self.layerdrop):
1126
+ layer_check = layer
1127
+ if isinstance(layer, FullyShardedDataParallel):
1128
+ layer_check = layer.unwrapped_module
1129
+ if (corpus_key is None) or (
1130
+ not isinstance(layer_check, (
1131
+ TransformerSentenceEncoderWithAdapterLayer,
1132
+ )
1133
+ )
1134
+ ):
1135
+ x, (z, lr) = layer(
1136
+ x, self_attn_padding_mask=padding_mask, need_weights=False
1137
+ )
1138
+ else:
1139
+ x, (z, lr) = layer(
1140
+ x,
1141
+ self_attn_padding_mask=padding_mask,
1142
+ need_weights=False,
1143
+ corpus_key=corpus_key,
1144
+ )
1145
+ if i >= min_layer:
1146
+ layer_results.append((x, z, lr))
1147
+ if i == tgt_layer:
1148
+ r = x
1149
+ break
1150
+
1151
+ if r is not None:
1152
+ x = r
1153
+
1154
+ # T x B x C -> B x T x C
1155
+ x = x.transpose(0, 1)
1156
+
1157
+ # undo padding
1158
+ if pad_length > 0:
1159
+ x = x[:, :-pad_length]
1160
+
1161
+ def undo_pad(a, b, c):
1162
+ return (
1163
+ a[:-pad_length],
1164
+ b[:-pad_length] if b is not None else b,
1165
+ c[:-pad_length],
1166
+ )
1167
+
1168
+ layer_results = [undo_pad(*u) for u in layer_results]
1169
+
1170
+ return x, layer_results
1171
+
1172
+ def max_positions(self):
1173
+ """Maximum output length supported by the encoder."""
1174
+ return self.args.max_positions
1175
+
1176
+ def upgrade_state_dict_named(self, state_dict, name):
1177
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
1178
+ return state_dict
1179
+
1180
+
1181
+ class ConformerEncoder(TransformerEncoder):
1182
+ def build_encoder_layer(self, args):
1183
+ layer = ConformerWav2Vec2EncoderLayer(
1184
+ embed_dim=self.embedding_dim,
1185
+ ffn_embed_dim=args.encoder_ffn_embed_dim,
1186
+ attention_heads=args.encoder_attention_heads,
1187
+ dropout=args.dropout,
1188
+ depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
1189
+ activation_fn="swish",
1190
+ attn_type=args.attn_type,
1191
+ pos_enc_type=args.pos_enc_type,
1192
+ use_fp16=args.fp16, # only used for rope
1193
+ )
1194
+ layer = fsdp_wrap(layer)
1195
+ if args.checkpoint_activations:
1196
+ layer = checkpoint_wrapper(layer)
1197
+ return layer
1198
+
1199
+ def __init__(self, args):
1200
+ super().__init__(args)
1201
+ self.args = args
1202
+ self.dropout = args.dropout
1203
+ self.embedding_dim = args.encoder_embed_dim
1204
+ self.pos_enc_type = args.pos_enc_type
1205
+ max_source_positions = self.max_positions()
1206
+
1207
+ if self.pos_enc_type == "rel_pos":
1208
+ self.embed_positions = RelPositionalEncoding(
1209
+ max_source_positions, self.embedding_dim
1210
+ )
1211
+ elif self.pos_enc_type == "rope":
1212
+ self.embed_positions = None
1213
+ else:
1214
+ raise Exception("Unsupported positional encoding type")
1215
+
1216
+ self.layers = nn.ModuleList(
1217
+ [self.build_encoder_layer(args) for _ in range(args.encoder_layers)]
1218
+ )
1219
+ self.layer_norm_first = args.layer_norm_first
1220
+ self.layer_norm = LayerNorm(self.embedding_dim)
1221
+ self.layerdrop = args.encoder_layerdrop
1222
+
1223
+ self.apply(init_bert_params)
1224
+
1225
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
1226
+ if padding_mask is not None:
1227
+ x = index_put(x, padding_mask, 0)
1228
+
1229
+ # B x T x C -> T x B x C
1230
+ x = x.transpose(0, 1)
1231
+
1232
+ # B X T X C here
1233
+ position_emb = None
1234
+ if self.pos_enc_type == "rel_pos":
1235
+ position_emb = self.embed_positions(x)
1236
+
1237
+ if not self.layer_norm_first:
1238
+ x = self.layer_norm(x)
1239
+
1240
+ x = F.dropout(x, p=self.dropout, training=self.training)
1241
+
1242
+ layer_results = []
1243
+ r = None
1244
+ for i, layer in enumerate(self.layers):
1245
+ dropout_probability = np.random.random()
1246
+ if not self.training or (dropout_probability > self.layerdrop):
1247
+ x, z = layer(
1248
+ x,
1249
+ self_attn_padding_mask=padding_mask,
1250
+ need_weights=False,
1251
+ position_emb=position_emb,
1252
+ )
1253
+ if tgt_layer is not None:
1254
+ layer_results.append((x, z))
1255
+ if i == tgt_layer:
1256
+ r = x
1257
+ break
1258
+
1259
+ if r is not None:
1260
+ x = r
1261
+
1262
+ # T x B x C -> B x T x C
1263
+ x = x.transpose(0, 1)
1264
+
1265
+ return x, layer_results
1266
+
1267
+
1268
+ class TransformerSentenceEncoderLayer(nn.Module):
1269
+ """
1270
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
1271
+ models.
1272
+ """
1273
+
1274
+ def __init__(
1275
+ self,
1276
+ embedding_dim: float = 768,
1277
+ ffn_embedding_dim: float = 3072,
1278
+ num_attention_heads: int = 8,
1279
+ dropout: float = 0.1,
1280
+ attention_dropout: float = 0.1,
1281
+ activation_dropout: float = 0.1,
1282
+ activation_fn: str = "relu",
1283
+ layer_norm_first: bool = False,
1284
+ ) -> None:
1285
+
1286
+ super().__init__()
1287
+ # Initialize parameters
1288
+ self.embedding_dim = embedding_dim
1289
+ self.dropout = dropout
1290
+ self.activation_dropout = activation_dropout
1291
+
1292
+ # Initialize blocks
1293
+ self.activation_fn = utils.get_activation_fn(activation_fn)
1294
+ self.self_attn = MultiheadAttention(
1295
+ self.embedding_dim,
1296
+ num_attention_heads,
1297
+ dropout=attention_dropout,
1298
+ self_attention=True,
1299
+ )
1300
+
1301
+ self.dropout1 = nn.Dropout(dropout)
1302
+ self.dropout2 = nn.Dropout(self.activation_dropout)
1303
+ self.dropout3 = nn.Dropout(dropout)
1304
+
1305
+ self.layer_norm_first = layer_norm_first
1306
+
1307
+ # layer norm associated with the self attention layer
1308
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
1309
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
1310
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
1311
+
1312
+ # layer norm associated with the position wise feed-forward NN
1313
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
1314
+
1315
+ def forward(
1316
+ self,
1317
+ x: torch.Tensor,
1318
+ self_attn_mask: torch.Tensor = None,
1319
+ self_attn_padding_mask: torch.Tensor = None,
1320
+ need_weights: bool = False,
1321
+ att_args=None,
1322
+ ):
1323
+ """
1324
+ LayerNorm is applied either before or after the self-attention/ffn
1325
+ modules, similar to the original Transformer implementation.
1326
+ """
1327
+ residual = x
1328
+
1329
+ if self.layer_norm_first:
1330
+ x = self.self_attn_layer_norm(x)
1331
+ x, attn = self.self_attn(
1332
+ query=x,
1333
+ key=x,
1334
+ value=x,
1335
+ key_padding_mask=self_attn_padding_mask,
1336
+ attn_mask=self_attn_mask,
1337
+ need_weights=False,
1338
+ )
1339
+ x = self.dropout1(x)
1340
+ x = residual + x
1341
+
1342
+ residual = x
1343
+ x = self.final_layer_norm(x)
1344
+ x = self.activation_fn(self.fc1(x))
1345
+ x = self.dropout2(x)
1346
+ x = self.fc2(x)
1347
+
1348
+ layer_result = x
1349
+
1350
+ x = self.dropout3(x)
1351
+ x = residual + x
1352
+ else:
1353
+ x, attn = self.self_attn(
1354
+ query=x,
1355
+ key=x,
1356
+ value=x,
1357
+ key_padding_mask=self_attn_padding_mask,
1358
+ need_weights=False,
1359
+ )
1360
+
1361
+ x = self.dropout1(x)
1362
+ x = residual + x
1363
+
1364
+ x = self.self_attn_layer_norm(x)
1365
+
1366
+ residual = x
1367
+ x = self.activation_fn(self.fc1(x))
1368
+ x = self.dropout2(x)
1369
+ x = self.fc2(x)
1370
+
1371
+ layer_result = x
1372
+
1373
+ x = self.dropout3(x)
1374
+ x = residual + x
1375
+ x = self.final_layer_norm(x)
1376
+
1377
+ return x, (attn, layer_result)
1378
+
1379
+
1380
+ class AdapterFast(nn.Module):
1381
+ def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
1382
+ """
1383
+ Implements adapter modules directly with 3D tensor weights as parameters
1384
+ and without using ModuleList, to speed up training throughput.
1385
+ """
1386
+ super().__init__()
1387
+
1388
+ self.adapter_num = adapter_num
1389
+ self.input_dim = input_dim
1390
+ self.hidden_dim = hidden_dim
1391
+ self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
1392
+ self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
1393
+ self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
1394
+ self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
1395
+
1396
+ self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
1397
+ self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
1398
+ self.act_fn = nn.Identity()
1399
+ if act_fn == "relu":
1400
+ self.act_fn = nn.ReLU()
1401
+ elif act_fn == "gelu":
1402
+ self.act_fn = nn.GELU()
1403
+ elif act_fn == "selu":
1404
+ self.act_fn = nn.SELU()
1405
+ else:
1406
+ raise ValueError(f"unsupported {act_fn}")
1407
+
1408
+
1409
+ self.input_dim = input_dim
1410
+ self.reset_parameters()
1411
+
1412
+ def reset_parameters(self):
1413
+ for ii in range(self.adapter_num):
1414
+ nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
1415
+ nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
1416
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
1417
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
1418
+ nn.init.uniform_(self.b_a[ii], -bound, bound)
1419
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
1420
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
1421
+ nn.init.uniform_(self.b_b[ii], -bound, bound)
1422
+
1423
+ nn.init.ones_(self.ln_W)
1424
+ nn.init.zeros_(self.ln_b)
1425
+
1426
+ def forward(self, x, adapter_id):
1427
+ ii = adapter_id
1428
+ h = x
1429
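+ # per-corpus bottleneck: LayerNorm -> down-project to hidden_dim -> activation -> up-project back to input_dim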
+ h = F.layer_norm(h, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii])
1430
+ h = F.linear(h, self.W_a[ii], self.b_a[ii])
1431
+ h = self.act_fn(h)
1432
+ h = F.linear(h, self.W_b[ii], self.b_b[ii])
1433
+ outputs = h
1434
+ return outputs
1435
+
1436
+ def extra_repr(self):
1437
+ return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
1438
+
1439
+
1440
+
1441
+ class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
1442
+ """
1443
+ Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained
1444
+ models. An adapter module is added alongside the vanilla Transformer module.
1445
+ """
1446
+
1447
+ def __init__(
1448
+ self,
1449
+ embedding_dim: float = 768,
1450
+ ffn_embedding_dim: float = 3072,
1451
+ num_attention_heads: int = 8,
1452
+ dropout: float = 0.1,
1453
+ attention_dropout: float = 0.1,
1454
+ activation_dropout: float = 0.1,
1455
+ activation_fn: str = "relu",
1456
+ layer_norm_first: bool = False,
1457
+ adapter_num=201,
1458
+ adapter_dim=64,
1459
+ adapter_act_fn="relu",
1460
+ ) -> None:
1461
+
1462
+ super().__init__(
1463
+ embedding_dim=embedding_dim,
1464
+ ffn_embedding_dim=ffn_embedding_dim,
1465
+ num_attention_heads=num_attention_heads,
1466
+ dropout=dropout,
1467
+ attention_dropout=attention_dropout,
1468
+ activation_dropout=activation_dropout,
1469
+ activation_fn=activation_fn,
1470
+ layer_norm_first=layer_norm_first,
1471
+
1472
+ )
1473
+
1474
+ self.adapter_num = adapter_num
1475
+ self.adapter_dim = adapter_dim
1476
+ self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
1477
+
1478
+ def forward(
1479
+ self,
1480
+ x: torch.Tensor,
1481
+ self_attn_mask: torch.Tensor = None,
1482
+ self_attn_padding_mask: torch.Tensor = None,
1483
+ need_weights: bool = False,
1484
+ att_args=None,
1485
+ corpus_key=None,
1486
+ ):
1487
+
1488
+ x, (attn, layer_result) = super().forward(
1489
+ x=x,
1490
+ self_attn_mask=self_attn_mask,
1491
+ self_attn_padding_mask=self_attn_padding_mask,
1492
+ need_weights=need_weights,
1493
+ att_args=att_args,
1494
+ )
1495
+ assert corpus_key is not None
1496
+ assert len(set(corpus_key)) == 1, f"corpus_key items are not same {corpus_key}"
1497
+ y = self.adapter_layer(x, corpus_key[0])
1498
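+ # add the corpus-specific adapter output as a residual on top of the base layer output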
+ x = x + y
1499
+ return x, (attn, layer_result)
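
For reference, a minimal usage sketch of the model defined above (assuming this file is importable as a module and the fairseq packages it imports are installed; the module path, batch size, and frame count are illustrative):

import torch
from wav2vec2 import Wav2Vec2Config, Wav2Vec2Model  # hypothetical import path for this file

cfg = Wav2Vec2Config()                  # defaults: 12 transformer layers, 768-dim encoder, 40-dim input features
model = Wav2Vec2Model.build_model(cfg)
model.eval()

# this variant consumes pre-computed fbank/MFCC features shaped (batch, frames, cfg.input_feature_ndim)
feats = torch.randn(2, 1000, cfg.input_feature_ndim)
with torch.no_grad():
    out = model.extract_features(feats, padding_mask=None, mask=False)
print(out["x"].shape)                   # (batch, downsampled frames, cfg.encoder_embed_dim)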