import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    """DINO projection head: an MLP projector to a low-dimensional bottleneck,
    L2 normalization, and a weight-normalized linear layer producing
    per-prototype logits."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        # Weight-normalized output layer; its magnitude (weight_g) is initialized to 1.
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # Use a larger eps in half precision to keep the normalization numerically stable.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
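

# Minimal usage sketch (illustrative, not part of the original module): the
# feature dimension and prototype count below are hypothetical values chosen
# only to exercise the head end to end.
if __name__ == "__main__":
    head = DINOHead(in_dim=384, out_dim=65536)
    features = torch.randn(4, 384)  # a batch of 4 backbone embeddings
    logits = head(features)
    print(logits.shape)  # expected: torch.Size([4, 65536])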