jasspier committed
Commit 542b1ba
1 Parent(s): ac47f83

Create data2vec2.py

Files changed (1)
data2vec2.py +815 -0
data2vec2.py ADDED
@@ -0,0 +1,815 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from dataclasses import dataclass, field
from typing import Optional, Callable
from functools import partial
import numpy as np

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig

from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model

from examples.data2vec.data.modality import Modality

from examples.data2vec.models.modalities.base import (
    MaskSeed,
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_annealed_rate,
)
from examples.data2vec.models.modalities.modules import (
    D2vDecoderConfig,
    AltBlock,
    Decoder1d,
)

from .modalities.audio import (
    D2vAudioConfig,
    AudioEncoder,
)
from examples.data2vec.models.modalities.images import (
    D2vImageConfig,
    ImageEncoder,
)
from examples.data2vec.models.modalities.text import (
    D2vTextConfig,
    TextEncoder,
)

logger = logging.getLogger(__name__)


@dataclass
class D2vModalitiesConfig(FairseqDataclass):
    audio: D2vAudioConfig = D2vAudioConfig()
    image: D2vImageConfig = D2vImageConfig()
    text: D2vTextConfig = D2vTextConfig()


@dataclass
class Data2VecMultiConfig(FairseqDataclass):

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )

    input_feature_ndim: int = 40
    depth: int = 8
    start_drop_path_rate: float = 0
    end_drop_path_rate: float = 0
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True
    encoder_dropout: float = 0.1
    post_mlp_drop: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0
    embed_dim: int = 768
    mlp_ratio: float = 4
    layer_norm_first: bool = False

    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    end_of_block_targets: bool = False

    clone_batch: int = 1

    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_same_dtype: bool = True
    log_norms: bool = True
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_encoder_only: bool = field(
        default=True,
        metadata={
            "help": "whether to momentum update only the shared transformer encoder"
        },
    )

    max_update: int = II("optimization.max_update")

    modalities: D2vModalitiesConfig = D2vModalitiesConfig()

    shared_decoder: Optional[D2vDecoderConfig] = None

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )

    supported_modality: Optional[Modality] = None
    mae_init: bool = False

    seed: int = II("common.seed")

    skip_ema: bool = False

    cls_loss: float = 0
    recon_loss: float = 0
    d2v_loss: float = 1

    decoder_group: bool = False


@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
class Data2VecMultiModel(BaseFairseqModel):
    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        if cfg.type == Modality.AUDIO:
            enc_cls = AudioEncoder
        elif cfg.type == Modality.IMAGE:
            enc_cls = ImageEncoder
        elif cfg.type == Modality.TEXT:
            enc_cls = TextEncoder
            if hasattr(task, "text_task"):
                task = task.text_task
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            enc = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )
            self.modality_encoders[mod.name] = enc

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        self.norm = None
        if cfg.layer_norm_first:
            self.norm = make_layer_norm(cfg.embed_dim)

        if self.cfg.mae_init:
            self.apply(self._init_weights)
        else:
            from fairseq.modules.transformer_sentence_encoder import init_bert_params

            self.apply(init_bert_params)

        for mod_enc in self.modality_encoders.values():
            mod_enc.reset_parameters()

        if not skip_ema:
            self.ema = self.make_ema_teacher(cfg.ema_decay)
            self.shared_decoder = (
                Decoder1d(cfg.shared_decoder, cfg.embed_dim)
                if self.cfg.shared_decoder is not None
                else None
            )
            if self.shared_decoder is not None:
                self.shared_decoder.apply(self._init_weights)

            self.recon_proj = None
            if cfg.recon_loss > 0:
                self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)

        for pn, p in self.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
                p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
            if cfg.decoder_group and "decoder" in pn:
                p.param_group = "decoder"

        self.num_updates = 0

    def _init_weights(self, m):

        try:
            from apex.normalization import FusedLayerNorm

            fn = FusedLayerNorm
        except:
            fn = nn.LayerNorm

        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.no_grad()
    def make_ema_teacher(self, ema_decay):
        ema_config = EMAModuleConfig(
            ema_decay=ema_decay,
            ema_fp32=True,
            log_norms=self.cfg.log_norms,
            add_missing_params=False,
        )

        model_copy = self.make_target_model()

        return EMAModule(
            model_copy,
            ema_config,
            copy_model=False,
        )

    def make_target_model(self):
        logger.info("making target model")

        model_copy = Data2VecMultiModel(
            self.cfg, self.modalities, skip_ema=True, task=self.task
        )

        if self.cfg.ema_encoder_only:
            model_copy = model_copy.blocks
            for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
                p_t.data.copy_(p_s.data)
        else:
            for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
                p_t.data.copy_(p_s.data)

            for mod_enc in model_copy.modality_encoders.values():
                mod_enc.decoder = None
                if not mod_enc.modality_cfg.ema_local_encoder:
                    mod_enc.local_encoder = None
                    mod_enc.project_features = None

        model_copy.requires_grad_(False)
        return model_copy

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is not None and (
            (self.num_updates == 0 and num_updates > 1)
            or self.num_updates >= num_updates
        ):
            pass
        elif self.training and self.ema is not None:
            ema_weight_decay = None
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay, weight_decay=ema_weight_decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        k = prefix + "_ema"
        if self.ema is not None:
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        elif k in state_dict:
            del state_dict[k]

        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecMultiConfig, task=None):
        """Build a new model instance."""
        if task is None or not hasattr(task, "supported_modalities"):
            modalities = (
                [cfg.supported_modality]
                if cfg.supported_modality is not None
                else [
                    Modality.AUDIO,
                    Modality.IMAGE,
                    Modality.TEXT,
                ]
            )
        else:
            modalities = task.supported_modalities
        return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        corpus_key=None,  # for config compatibility
    ):
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.cfg.layerdrop == 0
                or (np.random.random() > self.cfg.layerdrop)
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append((x, lr))

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            if remove_extra_tokens:
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens :
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

        xs = []

        if self.shared_decoder is not None:
            dx = self.forward_decoder(
                x,
                feature_extractor,
                self.shared_decoder,
                encoder_mask,
            )
            xs.append(dx)
        if feature_extractor.decoder is not None:
            dx = self.forward_decoder(
                x,
                feature_extractor,
                feature_extractor.decoder,
                encoder_mask,
            )
            xs.append(dx)
            orig_x = x

        assert len(xs) > 0

        p = next(self.ema.model.parameters())
        device = x.device
        dtype = x.dtype
        ema_device = p.device
        ema_dtype = p.dtype

        if not self.cfg.ema_same_dtype:
            dtype = ema_dtype

        if ema_device != device or ema_dtype != dtype:
            logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
            self.ema.model = self.ema.model.to(dtype=dtype, device=device)
            ema_dtype = dtype

            def to_device(d):
                for k, p in d.items():
                    if isinstance(d[k], dict):
                        to_device(d[k])
                    else:
                        d[k] = p.to(device=device)

            to_device(self.ema.fp32_params)
        tm = self.ema.model

        with torch.no_grad():
            tm.eval()

            if self.cfg.ema_encoder_only:
                assert target is None
                ema_input = extractor_out["local_features"]
                ema_input = feature_extractor.contextualized_features(
                    ema_input.to(dtype=ema_dtype),
                    padding_mask,
                    mask=False,
                    remove_masked=False,
                )
                ema_blocks = tm
            else:
                ema_blocks = tm.blocks
                if feature_extractor.modality_cfg.ema_local_encoder:
                    inp = (
                        target.to(dtype=ema_dtype)
                        if target is not None
                        else source.to(dtype=ema_dtype)
                    )
                    ema_input = tm.modality_encoders[mode](
                        inp,
                        padding_mask,
                        mask=False,
                        remove_masked=False,
                    )
                else:
                    assert target is None
                    ema_input = extractor_out["local_features"]
                    ema_feature_enc = tm.modality_encoders[mode]
                    ema_input = ema_feature_enc.contextualized_features(
                        ema_input.to(dtype=ema_dtype),
                        padding_mask,
                        mask=False,
                        remove_masked=False,
                    )

            ema_padding_mask = ema_input["padding_mask"]
            ema_alibi_bias = ema_input.get("alibi_bias", None)
            ema_alibi_scale = ema_input.get("alibi_scale", None)
            ema_input = ema_input["x"]

            y = []
            ema_x = []
            extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
            for i, blk in enumerate(ema_blocks):
                ab = ema_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        ema_alibi_scale[i]
                        if ema_alibi_scale.size(0) > 1
                        else ema_alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                ema_input, lr = blk(
                    ema_input,
                    padding_mask=ema_padding_mask,
                    alibi_bias=ab,
                )
                y.append(lr[:, extra_tokens:])
                ema_x.append(ema_input[:, extra_tokens:])

        y = self.make_targets(y, self.average_top_k_layers)
        orig_targets = y

        if self.cfg.clone_batch > 1:
            y = y.repeat_interleave(self.cfg.clone_batch, 0)

        masked = encoder_mask.mask.unsqueeze(-1)
        masked_b = encoder_mask.mask.bool()
        y = y[masked_b]

        if xs[0].size(1) == masked_b.size(1):
            xs = [x[masked_b] for x in xs]
        else:
            xs = [x.reshape(-1, x.size(-1)) for x in xs]

        sample_size = masked.sum().long()

        result = {
            "losses": {},
            "sample_size": sample_size,
        }

        sample_size = result["sample_size"]

        if self.cfg.cls_loss > 0:
            assert extra_tokens > 0
            cls_target = orig_targets.mean(dim=1)
            if self.cfg.clone_batch > 1:
                cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
            cls_pred = x[:, extra_tokens - 1]
            result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
                self.cfg.cls_loss * sample_size
            )

        if self.cfg.recon_loss > 0:

            with torch.no_grad():
                target = feature_extractor.patchify(source)
                mean = target.mean(dim=-1, keepdim=True)
                var = target.var(dim=-1, keepdim=True)
                target = (target - mean) / (var + 1.0e-6) ** 0.5

                if self.cfg.clone_batch > 1:
                    target = target.repeat_interleave(self.cfg.clone_batch, 0)

                if masked_b is not None:
                    target = target[masked_b]

            recon = xs[0]
            if self.recon_proj is not None:
                recon = self.recon_proj(recon)

            result["losses"]["recon"] = (
                self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
            )

        if self.cfg.d2v_loss > 0:
            for i, x in enumerate(xs):
                reg_loss = self.d2v_loss(x, y)
                n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
                result["losses"][n] = reg_loss * self.cfg.d2v_loss

        suffix = "" if len(self.modalities) == 1 else f"_{mode}"
        with torch.no_grad():
            if encoder_mask is not None:
                result["masked_pct"] = 1 - (
                    encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
                )
            for i, x in enumerate(xs):
                n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
                result[n] = self.compute_var(x.float())
            if self.ema is not None:
                for k, v in self.ema.logs.items():
                    result[k] = v

            y = y.float()
            result[f"target_var{suffix}"] = self.compute_var(y)

            if self.num_updates > 5000:
                if result[f"target_var{suffix}"] < self.cfg.min_target_var:
                    logger.error(
                        f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
                    )
                    raise Exception(
                        f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
                    )

                for k in result.keys():
                    if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
                        logger.error(
                            f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
                        )
                        raise Exception(
                            f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
                        )

        result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    def forward_decoder(
        self,
        x,
        feature_extractor,
        decoder,
        mask_info,
    ):
        x = feature_extractor.decoder_input(x, mask_info)
        x = decoder(*x)

        return x

    def d2v_loss(self, x, y):
        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, x.size(-1))

        if self.loss_beta == 0:
            loss = F.mse_loss(x, y, reduction="none")
        else:
            loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        reg_loss = loss * scale

        return reg_loss

    def make_targets(self, y, num_layers):

        with torch.no_grad():
            target_layer_results = y[-num_layers:]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BTC -> BCT
                ]
                permuted = True
            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]
            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]
            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]
            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

        y = target_layer_results[0].float()
        for tl in target_layer_results[1:]:
            y.add_(tl.float())
        y = y.div_(len(target_layer_results))

        if self.cfg.layer_norm_targets:
            y = F.layer_norm(y, y.shape[-1:])

        if self.cfg.instance_norm_targets:
            y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

        return y

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y**2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def remove_pretraining_modules(self, modality=None, keep_decoder=False):
        self.ema = None
        self.cfg.clone_batch = 1
        self.recon_proj = None

        if not keep_decoder:
            self.shared_decoder = None

        modality = modality.lower() if modality is not None else None
        for k in list(self.modality_encoders.keys()):
            if modality is not None and k.lower() != modality:
                del self.modality_encoders[k]
            else:
                self.modality_encoders[k].remove_pretraining_modules(
                    keep_decoder=keep_decoder
                )
                if not keep_decoder:
                    self.modality_encoders[k].decoder = None