# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Misc functions, including distributed helpers. Mostly copy-paste from torchvision references. """ import os import subprocess from typing import Any, Dict, List, Optional import torch import torchvision from torch import Tensor def get_sha(): cwd = os.path.dirname(os.path.abspath(__file__)) def _run(command): return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() sha = "N/A" diff = "clean" branch = "N/A" try: sha = _run(["git", "rev-parse", "HEAD"]) subprocess.check_output(["git", "diff"], cwd=cwd) diff = _run(["git", "diff-index", "HEAD"]) diff = "has uncommited changes" if diff else "clean" branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) except Exception: pass message = f"sha: {sha}, status: {diff}, branch: {branch}" return message def collate_fn(do_round, batch): batch = list(zip(*batch)) final_batch = {} final_batch["samples"] = NestedTensor.from_tensor_list(batch[0], do_round) final_batch["targets"] = batch[1] if "positive_map" in batch[1][0]: # we batch the positive maps here # Since in general each batch element will have a different number of boxes, # we collapse a single batch dimension to avoid padding. This is sufficient for our purposes. max_len = max([v["positive_map"].shape[1] for v in batch[1]]) nb_boxes = sum([v["positive_map"].shape[0] for v in batch[1]]) batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool) cur_count = 0 for v in batch[1]: cur_pos = v["positive_map"] batched_pos_map[cur_count : cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos cur_count += len(cur_pos) assert cur_count == len(batched_pos_map) # assert batched_pos_map.sum().item() == sum([v["positive_map"].sum().item() for v in batch[1]]) final_batch["positive_map"] = batched_pos_map.float() if "positive_map_eval" in batch[1][0]: # we batch the positive maps here # Since in general each batch element will have a different number of boxes, # we collapse a single batch dimension to avoid padding. This is sufficient for our purposes. max_len = max([v["positive_map_eval"].shape[1] for v in batch[1]]) nb_boxes = sum([v["positive_map_eval"].shape[0] for v in batch[1]]) batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool) cur_count = 0 for v in batch[1]: cur_pos = v["positive_map_eval"] batched_pos_map[cur_count : cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos cur_count += len(cur_pos) assert cur_count == len(batched_pos_map) # assert batched_pos_map.sum().item() == sum([v["positive_map"].sum().item() for v in batch[1]]) final_batch["positive_map_eval"] = batched_pos_map.float() if "answer" in batch[1][0] or "answer_type" in batch[1][0]: answers = {} for f in batch[1][0].keys(): if "answer" not in f: continue answers[f] = torch.stack([b[f] for b in batch[1]]) final_batch["answers"] = answers return final_batch class NestedTensor(object): def __init__(self, tensors, mask): self.tensors = tensors self.mask = mask def to(self, *args, **kwargs): cast_tensor = self.tensors.to(*args, **kwargs) cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None return type(self)(cast_tensor, cast_mask) def decompose(self): return self.tensors, self.mask @classmethod def from_tensor_list(cls, tensor_list, do_round=False): # TODO make this more general if tensor_list[0].ndim == 3: # TODO make it support different-sized images max_size = tuple(max(s) for s in zip(*[img.shape for img in tensor_list])) # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) batch_shape = (len(tensor_list),) + max_size b, c, h, w = batch_shape if do_round: # Round to an even size to avoid rounding issues in fpn p = 128 h = h if h % p == 0 else (h // p + 1) * p w = w if w % p == 0 else (w // p + 1) * p batch_shape = b, c, h, w dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], : img.shape[2]] = False else: raise ValueError("not supported") return cls(tensor, mask) def __repr__(self): return repr(self.tensors) def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor """ Equivalent to nn.functional.interpolate, but with support for empty channel sizes. """ if input.numel() > 0: return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) assert input.shape[0] != 0 or input.shape[1] != 0, "At least one of the two first dimensions must be non zero" if input.shape[1] == 0: # Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim return torch.nn.functional.interpolate(input.transpose(0, 1), size, scale_factor, mode, align_corners).transpose(0, 1) # empty batch dimension is now supported in pytorch return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) def targets_to(targets: List[Dict[str, Any]], device): """Moves the target dicts to the given device.""" excluded_keys = [ "questionId", "tokens_positive", "tokens", "dataset_name", "sentence_id", "original_img_id", "nb_eval", "task_id", "original_id", ] return [{k: v.to(device) if k not in excluded_keys else v for k, v in t.items() if k != "caption"} for t in targets]