Gosse Minnema
Re-enable LOME
2890e34
from typing import *
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from torch.nn.utils.rnn import pad_sequence
def num2mask(
    nums: torch.Tensor,
    max_length: Optional[int] = None
) -> torch.Tensor:
    """
    Turn a tensor of lengths into a boolean prefix mask.

    E.g. input a tensor [2, 3, 4], return [[T T F F], [T T T F], [T T T T]]

    :param nums: Lengths. Any shape is accepted; the mask gets one extra trailing dimension.
    :param max_length: Size of the trailing mask dimension. If ``None``, the largest value
        in ``nums`` is used (0 when ``nums`` is empty). An explicit 0 is honoured and
        yields a zero-width mask.
    :return: Boolean mask of shape ``[*nums.shape, max_length]`` where entry ``[..., j]``
        is True iff ``j < nums[...]``.
    """
    if max_length is None:
        # An empty tensor has no maximum, so fall back to a zero-width mask instead of
        # letting ``nums.max()`` raise.
        max_length = int(nums.max()) if nums.numel() > 0 else 0
    positions = torch.arange(max_length, device=nums.device)
    # Broadcasting [max_length] against [..., 1] yields [..., max_length] directly.
    return positions < nums.unsqueeze(-1)
def mask2idx(
    mask: torch.Tensor,
    max_length: Optional[int] = None,
    padding_value: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Collect, for every row of a mask, the positions that are set.

    E.g. input a tensor [[T T F F], [T T T F], [F F F T]] with padding value -1,
    return [[0, 1, -1], [0, 1, 2], [3, -1, -1]] together with the counts [2, 3, 1].

    :param mask: Mask tensor. Boolean. Must have at least 2 dimensions; all leading
        dimensions are treated as batch dimensions.
    :param max_length: If provided, the index dimension is truncated to at most this many
        entries. It is never padded up to this size.
    :param padding_value: Padding value. Default to 0.
    :return: A tuple ``(indices, counts)``. ``indices`` has shape ``[*mask.shape[:-1], k]``,
        where ``k`` is the largest number of selected positions in any row (after
        truncation). ``counts`` is ``mask.sum(-1)``; it is computed from the full mask and
        is NOT clipped by ``max_length``.
    """
    shape_prefix, mask_length = mask.shape[:-1], mask.shape[-1]
    flat_mask = mask.flatten(0, -2)
    # The position vector is the same for every row, so build it once.
    positions = torch.arange(mask_length, device=mask.device)
    index_list = [positions[one_mask] for one_mask in flat_mask.unbind(0)]
    if index_list:
        index_tensor = pad_sequence(index_list, batch_first=True, padding_value=padding_value)
    else:
        # pad_sequence cannot handle an empty list (a mask with zero rows).
        index_tensor = torch.empty((0, 0), dtype=torch.long, device=mask.device)
    if max_length is not None:
        index_tensor = index_tensor[:, :max_length]
    # Give the last dimension explicitly: reshape(..., -1) is ambiguous, and raises, when
    # the tensor has zero elements (e.g. no position is selected in any row).
    index_tensor = index_tensor.reshape(*shape_prefix, index_tensor.shape[-1])
    return index_tensor, mask.sum(-1)
def one_hot(tags: torch.Tensor, num_tags: Optional[int] = None) -> torch.Tensor:
    """
    One-hot encode integer tags as a boolean tensor.

    E.g. input [[0, 2], [1, 1]], return a tensor of shape [2, 2, 3] whose entry
    ``[i, j, t]`` is True iff ``tags[i, j] == t``.

    :param tags: Integer tag indices (must be usable as a ``scatter_`` index). Any number
        of dimensions.
    :param num_tags: Size of the one-hot dimension. If ``None``, it is inferred as
        ``tags.max() + 1`` so that the largest tag still has a slot.
    :return: Boolean tensor of shape ``[*tags.shape, num_tags]``.
    """
    if num_tags is None:
        # Tags are zero-based indices, so the largest tag needs max + 1 slots.
        num_tags = int(tags.max()) + 1
    ret = tags.new_zeros(size=[*tags.shape, num_tags], dtype=torch.bool)
    # Scatter along the last (one-hot) dimension so tags of any rank are supported.
    ret.scatter_(-1, tags.unsqueeze(-1), tags.new_ones([*tags.shape, 1], dtype=torch.bool))
    return ret
def numpy2torch(
    dict_obj: dict
) -> dict:
    """
    Build a new dict in which every list / np.ndarray value is converted to a
    torch.Tensor with an extra leading batch dimension of size 1.

    Values of any other type are carried over unchanged. The input dict is not modified.

    :param dict_obj: Mapping whose values may be lists or numpy arrays.
    :return: New dict with the same keys.
    """
    return {
        key: torch.tensor(value).unsqueeze(0) if isinstance(value, (list, np.ndarray)) else value
        for key, value in dict_obj.items()
    }
def max_match(mat: np.ndarray):
    """
    Return the total score of the best one-to-one row/column assignment in ``mat``.

    The assignment is found with ``linear_sum_assignment`` in maximisation mode, and
    the scores of the selected cells are summed.

    :param mat: 2D score matrix.
    :return: Sum of ``mat`` over the selected (row, column) pairs.
    """
    rows, cols = linear_sum_assignment(mat, maximize=True)
    matched_scores = mat[rows, cols]
    return matched_scores.sum()