# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

from . import vision_transformer as vits


logger = logging.getLogger("dinov2")


def build_model(args, only_teacher=False, img_size=224):
    # Drop the legacy "_memeff" suffix so both architecture spellings resolve to the same builder.
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        # Teacher and student share the same ViT configuration; only the student
        # adds stochastic depth (drop path) regularization.
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        embed_dim = student.embed_dim
    return student, teacher, embed_dim


def build_model_from_cfg(cfg, only_teacher=False):
    return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
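

# Minimal usage sketch (kept as a comment; not part of the original module).
# It shows one way to call build_model() directly with an ad-hoc namespace.
# The field values below are illustrative assumptions, not official DINOv2 defaults.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(
#         arch="vit_small",
#         patch_size=14,
#         layerscale=1e-5,
#         ffn_layer="mlp",
#         block_chunks=0,
#         qkv_bias=True,
#         proj_bias=True,
#         ffn_bias=True,
#         num_register_tokens=0,
#         interpolate_offset=0.1,
#         interpolate_antialias=False,
#         drop_path_rate=0.3,
#         drop_path_uniform=True,
#     )
#     student, teacher, embed_dim = build_model(args, only_teacher=False, img_size=224)
#
# build_model_from_cfg() does the same thing, pulling these fields from cfg.student
# and the image size from cfg.crops.global_crops_size.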