# NOTE(review): wildcard import kept as-is in case other code relies on it;
# the functions below only use Optional and Tuple from typing.
from typing import *

import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from torch.nn.utils.rnn import pad_sequence


def num2mask(
        nums: torch.Tensor,
        max_length: Optional[int] = None
) -> torch.Tensor:
    """
    Convert lengths into a boolean prefix mask.

    E.g. input a tensor [2, 3, 4], return
    [[T T F F], [T T T F], [T T T T]]

    :param nums: Lengths, any shape (typically [batch]).
    :param max_length: Size of the mask's last dimension. If not provided (or 0),
        the largest number in nums is used; an empty nums then yields 0.
    :return: Boolean mask of shape [*nums.shape, max_length] whose entry [..., j]
        is True iff j < nums[...].
    """
    shape_backup = nums.shape
    nums = nums.flatten()
    if not max_length:
        # nums.max() raises on an empty tensor, so fall back to a zero-width mask.
        max_length = int(nums.max()) if nums.numel() > 0 else 0
    positions = torch.arange(max_length, device=nums.device)
    # Broadcast [1, max_length] against [batch, 1] -> [batch, max_length].
    ret = positions.unsqueeze(0) < nums.unsqueeze(1)
    return ret.reshape(*shape_backup, max_length)


def mask2idx(
        mask: torch.Tensor,
        max_length: Optional[int] = None,
        padding_value: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Collect the positions of True entries along the last dimension, left-aligned and padded.

    E.g. input a tensor [[T T F F], [T T T F], [F F F T]] with padding value -1, return
    [[0, 1, -1], [0, 1, 2], [3, -1, -1]]

    :param mask: Boolean mask of shape [*prefix, length]; any rank >= 1 is accepted.
    :param max_length: If provided, the index tensor is truncated to at most this many
        columns (it is not padded up to it).
    :param padding_value: Fill value for slots past each row's True count. Default to 0.
    :return: A pair (indices of shape [*prefix, k], per-row True counts of shape [*prefix]),
        where k is the largest count, capped by max_length. Counts are not truncated.
    """
    shape_prefix, mask_length = mask.shape[:-1], mask.shape[-1]
    # Collapse all leading dims into one row dim. An explicit row count (instead of -1
    # or flatten(0, -2)) keeps this valid for 1D masks and for zero-element masks.
    flat_mask = mask.reshape(shape_prefix.numel(), mask_length)
    positions = torch.arange(mask_length, device=mask.device)
    index_list = [positions[row] for row in flat_mask.unbind(0)]
    if index_list:
        index_tensor = pad_sequence(index_list, batch_first=True, padding_value=padding_value)
    else:
        # pad_sequence rejects an empty list; with no rows there is nothing to index.
        index_tensor = positions.new_empty((0, 0))
    if max_length is not None:
        index_tensor = index_tensor[:, :max_length]
    # Spell out the last dim: reshape(..., -1) is ambiguous when the tensor is empty.
    index_tensor = index_tensor.reshape(*shape_prefix, index_tensor.shape[-1])
    return index_tensor, mask.sum(-1)


def one_hot(tags: torch.Tensor, num_tags: Optional[int] = None) -> torch.Tensor:
    """
    One-hot encode integer tags along a new trailing dimension.

    :param tags: Tag indices of any shape. They are used as a scatter_ index, so they
        must be int64, and every value must lie in [0, num_tags).
    :param num_tags: Number of classes. If not provided (or 0), inferred as max(tags) + 1.
    :return: Boolean tensor of shape [*tags.shape, num_tags].
    """
    if not num_tags:
        # Tags are 0-based, so the largest tag needs max + 1 slots.
        num_tags = int(tags.max()) + 1 if tags.numel() > 0 else 0
    ret = tags.new_zeros(size=[*tags.shape, num_tags], dtype=torch.bool)
    # Scatter along the new last dim so tags of any rank work, not only 2D.
    ret.scatter_(-1, tags.unsqueeze(-1), tags.new_ones([*tags.shape, 1], dtype=torch.bool))
    return ret


def numpy2torch(
        dict_obj: dict
) -> dict:
    """
    Convert list/np.ndarray values to torch.Tensor and prepend a batch dim of size 1.

    Values of any other type are copied into the result unchanged.
    :param dict_obj: Mapping whose values may be lists, arrays or anything else.
    :return: A new dict with the same keys.
    """
    ret = {}
    for k, v in dict_obj.items():
        if isinstance(v, (list, np.ndarray)):
            ret[k] = torch.tensor(v).unsqueeze(0)
        else:
            ret[k] = v
    return ret


def max_match(mat: np.ndarray):
    """
    Return the total weight of a maximum-weight one-to-one row/column assignment.

    Uses scipy's linear_sum_assignment with maximize=True and sums the matched entries.
    :param mat: 2D weight matrix.
    :return: Sum of mat over the chosen (row, column) pairs.
    """
    row_idx, col_idx = linear_sum_assignment(mat, maximize=True)
    return mat[row_idx, col_idx].sum()