# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from timm.models.registry import register_model
from torchscale.model.BEiT3 import BEiT3
from torchscale.architecture.config import EncoderConfig
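

# Truncated-normal init clipped to +/- one std around the mean (wraps timm's trunc_normal_).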
def trunc_normal_(tensor, mean=0., std=1.):
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
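

# EncoderConfig for the base model: 768-dim embeddings, 12 heads, 12 layers,
# multiway (modality-expert) FFNs enabled.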
def _get_base_config(
        img_size=224, patch_size=16, drop_path_rate=0,
        checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
):
    return EncoderConfig(
        img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
        layernorm_embedding=False, normalize_output=True, no_output_layer=True,
        drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12,
        encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12,
        checkpoint_activations=checkpoint_activations,
    )
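

# EncoderConfig for the large model: 1024-dim embeddings, 16 heads, 24 layers.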
def _get_large_config(
        img_size=224, patch_size=16, drop_path_rate=0,
        checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
):
    return EncoderConfig(
        img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
        layernorm_embedding=False, normalize_output=True, no_output_layer=True,
        drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16,
        encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24,
        checkpoint_activations=checkpoint_activations,
    )
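

# Thin nn.Module wrapper around torchscale's BEiT3 encoder, used here as a
# vision backbone that can optionally attend over extra query embeddings.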
class BEiT3Wrapper(nn.Module):
    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        self.beit3 = BEiT3(args)
        self.apply(self._init_weights)
        # Masked-image-modeling head (1024 -> 8192); sizes are hardcoded here
        # rather than taken from args, matching the large config's hidden size.
        self.mim_head = nn.Linear(1024, 8192)
        self.num_img_patches = self.beit3.vision_embed.num_position_embeddings()
        self.hidden_size = args.encoder_embed_dim
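
    # Divides attention and MLP output-projection weights by sqrt(2 * layer depth).
    # Note: iterates `self.blocks`, which this wrapper does not define itself.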
    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_num_layers(self):
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names that the optimizer should exclude from weight decay.
        return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'}
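
    # Module-wise init used by self.apply(): truncated-normal Linear weights,
    # zero biases, and unit-weight / zero-bias LayerNorm.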
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
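
    # Encodes an image batch; if `query_embed` is given, the learned queries are
    # appended to the image tokens and only their hidden states are returned.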
    def forward(self, pixel_values, query_embed=None):
        B = pixel_values.size(0)
        dtype = self.beit3.vision_embed.proj.weight.dtype
        pixel_values = pixel_values.to(dtype)  # match the patch-embedding weight dtype
        # Patch-embed the image into a sequence of token embeddings.
        token_embeddings = self.beit3.vision_embed(pixel_values)
        multiway_split_position = -1
        if query_embed is not None:
            # Append the learned query embeddings after the image tokens and
            # record where the multiway (modality-specific) split begins.
            query_embed = torch.stack([query_embed] * B)
            multiway_split_position = token_embeddings.size(1)
            token_embeddings = torch.cat([token_embeddings, query_embed], dim=1)
        outputs = self.beit3.encoder(
            src_tokens=None,
            token_embeddings=token_embeddings,
            multiway_split_position=multiway_split_position
        )
        vision_hidden_states = outputs["encoder_out"]
        if query_embed is not None:
            # Keep only the hidden states of the appended queries.
            vision_hidden_states = vision_hidden_states[:, self.num_img_patches:]
        return vision_hidden_states
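

# timm model-registry entry for the large backbone at 224x224 with 16x16 patches.
# Note: the `pretrained` flag is accepted but no weights are loaded here.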
@register_model
def beit3_large_patch16_224(pretrained=False, **kwargs):
    args = _get_large_config(img_size=224, **kwargs)
    model = BEiT3Wrapper(args, **kwargs)
    return model
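

# Minimal usage sketch (assumes torchscale and timm are installed): build the
# large backbone and encode a random 224x224 RGB batch. With 16x16 patches this
# should yield hidden states of shape [batch, 196 patches + 1 CLS, 1024].
if __name__ == "__main__":
    model = beit3_large_patch16_224()
    model.eval()
    pixel_values = torch.randn(2, 3, 224, 224)  # dummy image batch
    with torch.no_grad():
        vision_hidden_states = model(pixel_values)
    print(vision_hidden_states.shape)  # expected: torch.Size([2, 197, 1024])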