|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import random |
|
import subprocess |
|
from urllib.parse import urlparse |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
|
|
|
|
# Shared "dinov2" logger: every logging.getLogger call with this name returns the same object.
logger = logging.getLogger(name="dinov2")
|
|
|
|
|
def _strip_prefix(key, prefix):
    """Return ``key`` with every leading repetition of ``prefix`` removed."""
    while prefix and key.startswith(prefix):
        key = key[len(prefix):]
    return key


def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint into ``model`` with ``load_state_dict(strict=False)``.

    Args:
        model: module receiving the weights.
        pretrained_weights: a URL (fetched with
            ``torch.hub.load_state_dict_from_url``) or a local path, given as a
            ``str`` or path-like object and read with ``torch.load``.
        checkpoint_key: optional key selecting a sub-dict of the loaded
            checkpoint. It is ignored when ``None`` or not present.

    Leading ``"module."`` prefixes and then leading ``"backbone."`` prefixes are
    stripped from every key before loading. Because loading is non-strict,
    missing or unexpected keys do not raise. They are reported in the logged
    message returned by ``load_state_dict``.
    """
    pretrained_weights = os.fspath(pretrained_weights)

    # Treat the argument as a URL only when it is not an existing local file.
    # A Windows path such as "C:\\ckpt.pth" parses with scheme "c".
    if urlparse(pretrained_weights).scheme and not os.path.isfile(pretrained_weights):
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info("Take key %s in provided checkpoint dict", checkpoint_key)
        state_dict = state_dict[checkpoint_key]

    # Strip the wrapper prefixes only at the start of each key. A substring
    # replace would corrupt names such as "submodule.weight" -> "subweight".
    state_dict = {
        _strip_prefix(_strip_prefix(k, "module."), "backbone."): v for k, v in state_dict.items()
    }
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at %s and loaded with msg: %s", pretrained_weights, msg)
|
|
|
|
|
def fix_random_seeds(seed=31):
    """
    Seed torch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with the same ``seed``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
|
def get_sha():
    """Describe the git state of the directory containing this file.

    Returns:
        ``"sha: <commit>, status: <clean|has uncommitted changes>, branch: <name>"``.
        If git is missing, the directory is not a repository, or a git command
        fails, the remaining fields keep their defaults (``"N/A"`` for sha and
        branch, ``"clean"`` for status) instead of raising.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        # Discard git's stderr so messages like "fatal: not a git repository"
        # do not leak to the console. Decode as UTF-8 because branch names are
        # not limited to ASCII.
        output = subprocess.check_output(command, cwd=cwd, stderr=subprocess.DEVNULL)
        return output.decode("utf-8", errors="replace").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        # The output of `git diff` is unused. It is presumably run so git refreshes
        # the index's cached stat info before `diff-index` below. TODO confirm.
        subprocess.check_output(["git", "diff"], cwd=cwd, stderr=subprocess.DEVNULL)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except (OSError, subprocess.SubprocessError):
        # Best effort: git not installed, not a repository, or a command failed.
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message
|
|
|
|
|
class CosineScheduler(object):
    """Per-iteration value schedule: freeze, then linear warmup, then cosine decay.

    The precomputed ``schedule`` is the concatenation of:

    * ``freeze_iters`` zeros,
    * ``warmup_iters`` values linearly spaced from ``start_warmup_value`` to
      ``base_value`` (both endpoints included),
    * ``total_iters - warmup_iters - freeze_iters`` values of
      ``final_value + 0.5 * (base_value - final_value) * (1 + cos(pi * i / n))``
      for ``i`` in ``0..n-1``.

    ``scheduler[it]`` returns ``schedule[it]``, or ``final_value`` when
    ``it >= total_iters``.

    Raises:
        ValueError: if ``warmup_iters + freeze_iters`` exceeds ``total_iters``.
            NumPy also raises ``ValueError`` for negative ``warmup_iters`` or
            ``freeze_iters``.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        # Validate explicitly rather than relying only on the final `assert`,
        # which is stripped under `python -O`.
        if warmup_iters + freeze_iters > total_iters:
            raise ValueError(
                f"warmup_iters ({warmup_iters}) + freeze_iters ({freeze_iters}) "
                f"must not exceed total_iters ({total_iters})"
            )

        freeze_schedule = np.zeros((freeze_iters))

        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

        decay_iters = total_iters - warmup_iters - freeze_iters
        iters = np.arange(decay_iters)
        # The max() guard avoids dividing by zero when the cosine phase is empty.
        # Values are unchanged for decay_iters >= 1.
        schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / max(decay_iters, 1)))
        self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))

        # Internal invariant; guaranteed by the validation above.
        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        if it >= self.total_iters:
            return self.final_value
        else:
            return self.schedule[it]
|
|
|
|
|
def has_batchnorms(model):
    """Return ``True`` if ``model`` or any of its submodules is a batch-norm layer.

    A module matches if it is an instance of ``nn.BatchNorm1d``,
    ``nn.BatchNorm2d``, ``nn.BatchNorm3d`` or ``nn.SyncBatchNorm``, including
    subclasses of these. ``model.modules()`` also yields ``model`` itself.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for module in model.modules())
|
|