# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from dataclasses import dataclass, field
from typing import Optional, Callable
from functools import partial

import numpy as np

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model

from examples.data2vec.data.modality import Modality
from examples.data2vec.models.modalities.base import (
    MaskSeed,
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_annealed_rate,
)
from examples.data2vec.models.modalities.modules import (
    D2vDecoderConfig,
    AltBlock,
    Decoder1d,
)

from .modalities.audio import (
    D2vAudioConfig,
    AudioEncoder,
)
from examples.data2vec.models.modalities.images import (
    D2vImageConfig,
    ImageEncoder,
)
from examples.data2vec.models.modalities.text import (
    D2vTextConfig,
    TextEncoder,
)

logger = logging.getLogger(__name__)


@dataclass
class D2vModalitiesConfig(FairseqDataclass):
    audio: D2vAudioConfig = D2vAudioConfig()
    image: D2vImageConfig = D2vImageConfig()
    text: D2vTextConfig = D2vTextConfig()


@dataclass
class Data2VecMultiConfig(FairseqDataclass):

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )

    input_feature_ndim: int = 40

    depth: int = 8
    start_drop_path_rate: float = 0
    end_drop_path_rate: float = 0
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True
    encoder_dropout: float = 0.1
    post_mlp_drop: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0
    embed_dim: int = 768
    mlp_ratio: float = 4
    layer_norm_first: bool = False

    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    end_of_block_targets: bool = False

    clone_batch: int = 1

    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_same_dtype: bool = True
    log_norms: bool = True
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_encoder_only: bool = field(
        default=True,
        metadata={
            "help": "whether to momentum update only the shared transformer encoder"
        },
    )

    max_update: int = II("optimization.max_update")

    modalities: D2vModalitiesConfig = D2vModalitiesConfig()

    shared_decoder: Optional[D2vDecoderConfig] = None

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )

    supported_modality: Optional[Modality] = None
    mae_init: bool = False

    seed: int = II("common.seed")

    skip_ema: bool = False

    cls_loss: float = 0
    recon_loss: float = 0
    d2v_loss: float = 1

    decoder_group: bool = False
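

# data2vec 2.0 multi-modal model: a student encoder is trained to regress
# layer-averaged representations produced by an EMA teacher from unmasked
# inputs, while the student only sees a masked view of the input. Audio,
# image and text share the transformer blocks but use modality-specific
# local encoders and decoders.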
@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
class Data2VecMultiModel(BaseFairseqModel):
    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        if cfg.type == Modality.AUDIO:
            enc_cls = AudioEncoder
        elif cfg.type == Modality.IMAGE:
            enc_cls = ImageEncoder
        elif cfg.type == Modality.TEXT:
            enc_cls = TextEncoder
            if hasattr(task, "text_task"):
                task = task.text_task
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            enc = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )
            self.modality_encoders[mod.name] = enc

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        self.norm = None
        if cfg.layer_norm_first:
            self.norm = make_layer_norm(cfg.embed_dim)

        if self.cfg.mae_init:
            self.apply(self._init_weights)
        else:
            from fairseq.modules.transformer_sentence_encoder import init_bert_params

            self.apply(init_bert_params)

        for mod_enc in self.modality_encoders.values():
            mod_enc.reset_parameters()

        if not skip_ema:
            self.ema = self.make_ema_teacher(cfg.ema_decay)
            self.shared_decoder = (
                Decoder1d(cfg.shared_decoder, cfg.embed_dim)
                if self.cfg.shared_decoder is not None
                else None
            )
            if self.shared_decoder is not None:
                self.shared_decoder.apply(self._init_weights)

            self.recon_proj = None
            if cfg.recon_loss > 0:
                self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)

        for pn, p in self.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
                p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
            if cfg.decoder_group and "decoder" in pn:
                p.param_group = "decoder"

        self.num_updates = 0

    def _init_weights(self, m):

        try:
            from apex.normalization import FusedLayerNorm

            fn = FusedLayerNorm
        except:
            fn = nn.LayerNorm

        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
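
    # The EMA teacher is a frozen copy of the student tracked by fairseq's
    # EMAModule. With ema_encoder_only=True only the shared transformer blocks
    # are copied and momentum-updated; otherwise the full model is copied and
    # make_target_model() strips the decoders (and, unless ema_local_encoder is
    # set, the local feature encoders) that the teacher never uses.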
    @torch.no_grad()
    def make_ema_teacher(self, ema_decay):
        ema_config = EMAModuleConfig(
            ema_decay=ema_decay,
            ema_fp32=True,
            log_norms=self.cfg.log_norms,
            add_missing_params=False,
        )

        model_copy = self.make_target_model()

        return EMAModule(
            model_copy,
            ema_config,
            copy_model=False,
        )

    def make_target_model(self):
        logger.info("making target model")

        model_copy = Data2VecMultiModel(
            self.cfg, self.modalities, skip_ema=True, task=self.task
        )

        if self.cfg.ema_encoder_only:
            model_copy = model_copy.blocks
            for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
                p_t.data.copy_(p_s.data)
        else:
            for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
                p_t.data.copy_(p_s.data)

            for mod_enc in model_copy.modality_encoders.values():
                mod_enc.decoder = None
                if not mod_enc.modality_cfg.ema_local_encoder:
                    mod_enc.local_encoder = None
                    mod_enc.project_features = None

        model_copy.requires_grad_(False)
        return model_copy

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is not None and (
            (self.num_updates == 0 and num_updates > 1)
            or self.num_updates >= num_updates
        ):
            pass
        elif self.training and self.ema is not None:
            ema_weight_decay = None
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay, weight_decay=ema_weight_decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        k = prefix + "_ema"
        if self.ema is not None:
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        elif k in state_dict:
            del state_dict[k]

        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecMultiConfig, task=None):
        """Build a new model instance."""
        if task is None or not hasattr(task, "supported_modalities"):
            modalities = (
                [cfg.supported_modality]
                if cfg.supported_modality is not None
                else [
                    Modality.AUDIO,
                    Modality.IMAGE,
                    Modality.TEXT,
                ]
            )
        else:
            modalities = task.supported_modalities
        return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)
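
    # Training forward pass: the modality-specific encoder embeds and masks the
    # input, the shared blocks encode the visible tokens, the decoder(s) predict
    # representations at the masked positions, and the EMA teacher (run without
    # masking) provides the regression targets. With features_only=True the
    # teacher and losses are skipped and only contextualized features are
    # returned.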
    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        corpus_key=None,  # for config compatibility
    ):
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.cfg.layerdrop == 0
                or (np.random.random() > self.cfg.layerdrop)
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append((x, lr))

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            if remove_extra_tokens:
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens :
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

        xs = []

        if self.shared_decoder is not None:
            dx = self.forward_decoder(
                x,
                feature_extractor,
                self.shared_decoder,
                encoder_mask,
            )
            xs.append(dx)
        if feature_extractor.decoder is not None:
            dx = self.forward_decoder(
                x,
                feature_extractor,
                feature_extractor.decoder,
                encoder_mask,
            )
            xs.append(dx)
            orig_x = x

        assert len(xs) > 0

        p = next(self.ema.model.parameters())
        device = x.device
        dtype = x.dtype
        ema_device = p.device
        ema_dtype = p.dtype

        if not self.cfg.ema_same_dtype:
            dtype = ema_dtype

        if ema_device != device or ema_dtype != dtype:
            logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
            self.ema.model = self.ema.model.to(dtype=dtype, device=device)
            ema_dtype = dtype

            def to_device(d):
                for k, p in d.items():
                    if isinstance(d[k], dict):
                        to_device(d[k])
                    else:
                        d[k] = p.to(device=device)

            to_device(self.ema.fp32_params)

        tm = self.ema.model

        with torch.no_grad():
            tm.eval()

            if self.cfg.ema_encoder_only:
                assert target is None
                ema_input = extractor_out["local_features"]
                ema_input = feature_extractor.contextualized_features(
                    ema_input.to(dtype=ema_dtype),
                    padding_mask,
                    mask=False,
                    remove_masked=False,
                )
                ema_blocks = tm
            else:
                ema_blocks = tm.blocks
                if feature_extractor.modality_cfg.ema_local_encoder:
                    inp = (
                        target.to(dtype=ema_dtype)
                        if target is not None
                        else source.to(dtype=ema_dtype)
                    )
                    ema_input = tm.modality_encoders[mode](
                        inp,
                        padding_mask,
                        mask=False,
                        remove_masked=False,
                    )
                else:
                    assert target is None
                    ema_input = extractor_out["local_features"]
                    ema_feature_enc = tm.modality_encoders[mode]
                    ema_input = ema_feature_enc.contextualized_features(
                        ema_input.to(dtype=ema_dtype),
                        padding_mask,
                        mask=False,
                        remove_masked=False,
                    )

            ema_padding_mask = ema_input["padding_mask"]
            ema_alibi_bias = ema_input.get("alibi_bias", None)
            ema_alibi_scale = ema_input.get("alibi_scale", None)
            ema_input = ema_input["x"]

            y = []
            ema_x = []
            extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
            for i, blk in enumerate(ema_blocks):
                ab = ema_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        ema_alibi_scale[i]
                        if ema_alibi_scale.size(0) > 1
                        else ema_alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                ema_input, lr = blk(
                    ema_input,
                    padding_mask=ema_padding_mask,
                    alibi_bias=ab,
                )
                y.append(lr[:, extra_tokens:])
                ema_x.append(ema_input[:, extra_tokens:])

        y = self.make_targets(y, self.average_top_k_layers)
        orig_targets = y

        if self.cfg.clone_batch > 1:
            y = y.repeat_interleave(self.cfg.clone_batch, 0)

        masked = encoder_mask.mask.unsqueeze(-1)
        masked_b = encoder_mask.mask.bool()
        y = y[masked_b]

        if xs[0].size(1) == masked_b.size(1):
            xs = [x[masked_b] for x in xs]
        else:
            xs = [x.reshape(-1, x.size(-1)) for x in xs]

        sample_size = masked.sum().long()

        result = {
            "losses": {},
            "sample_size": sample_size,
        }

        sample_size = result["sample_size"]
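
        # Up to three loss terms are accumulated below: an optional CLS-token
        # loss against the time-averaged targets (cls_loss), an optional
        # MAE-style patch reconstruction loss with per-patch normalization
        # (recon_loss), and the main data2vec regression loss between decoder
        # outputs and the teacher targets (d2v_loss).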
        if self.cfg.cls_loss > 0:
            assert extra_tokens > 0
            cls_target = orig_targets.mean(dim=1)
            if self.cfg.clone_batch > 1:
                cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
            cls_pred = x[:, extra_tokens - 1]
            result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
                self.cfg.cls_loss * sample_size
            )

        if self.cfg.recon_loss > 0:

            with torch.no_grad():
                target = feature_extractor.patchify(source)
                mean = target.mean(dim=-1, keepdim=True)
                var = target.var(dim=-1, keepdim=True)
                target = (target - mean) / (var + 1.0e-6) ** 0.5

                if self.cfg.clone_batch > 1:
                    target = target.repeat_interleave(self.cfg.clone_batch, 0)

                if masked_b is not None:
                    target = target[masked_b]

            recon = xs[0]
            if self.recon_proj is not None:
                recon = self.recon_proj(recon)

            result["losses"]["recon"] = (
                self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
            )

        if self.cfg.d2v_loss > 0:
            for i, x in enumerate(xs):
                reg_loss = self.d2v_loss(x, y)
                n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
                result["losses"][n] = reg_loss * self.cfg.d2v_loss

        suffix = "" if len(self.modalities) == 1 else f"_{mode}"
        with torch.no_grad():
            if encoder_mask is not None:
                result["masked_pct"] = 1 - (
                    encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
                )

            for i, x in enumerate(xs):
                n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
                result[n] = self.compute_var(x.float())
            if self.ema is not None:
                for k, v in self.ema.logs.items():
                    result[k] = v

            y = y.float()
            result[f"target_var{suffix}"] = self.compute_var(y)

            if self.num_updates > 5000:
                if result[f"target_var{suffix}"] < self.cfg.min_target_var:
                    logger.error(
                        f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
                    )
                    raise Exception(
                        f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
                    )

                for k in result.keys():
                    if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
                        logger.error(
                            f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
                        )
                        raise Exception(
                            f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
                        )

            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    def forward_decoder(
        self,
        x,
        feature_extractor,
        decoder,
        mask_info,
    ):
        x = feature_extractor.decoder_input(x, mask_info)
        x = decoder(*x)

        return x

    def d2v_loss(self, x, y):
        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, x.size(-1))

        if self.loss_beta == 0:
            loss = F.mse_loss(x, y, reduction="none")
        else:
            loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        reg_loss = loss * scale

        return reg_loss
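
    # Teacher targets: take the top-K block outputs, optionally normalize each
    # layer (instance/batch/layer norm, controlled by the *_target_layer flags),
    # average them, and optionally normalize the averaged result once more
    # (layer_norm_targets / instance_norm_targets).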
    def make_targets(self, y, num_layers):

        with torch.no_grad():
            target_layer_results = y[-num_layers:]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BTC -> BCT
                ]
                permuted = True
            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]
            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]
            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]
            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

        y = target_layer_results[0].float()
        for tl in target_layer_results[1:]:
            y.add_(tl.float())
        y = y.div_(len(target_layer_results))

        if self.cfg.layer_norm_targets:
            y = F.layer_norm(y, y.shape[-1:])

        if self.cfg.instance_norm_targets:
            y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

        return y

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y**2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def remove_pretraining_modules(self, modality=None, keep_decoder=False):
        self.ema = None
        self.cfg.clone_batch = 1
        self.recon_proj = None

        if not keep_decoder:
            self.shared_decoder = None

        modality = modality.lower() if modality is not None else None
        for k in list(self.modality_encoders.keys()):
            if modality is not None and k.lower() != modality:
                del self.modality_encoders[k]
            else:
                self.modality_encoders[k].remove_pretraining_modules(
                    keep_decoder=keep_decoder
                )
                if not keep_decoder:
                    self.modality_encoders[k].decoder = None