|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def _downscale_nearest(arr: np.ndarray, scale_factor: float) -> np.ndarray:
    """Nearest-neighbour resize of the last two axes of ``arr`` by ``scale_factor``.

    Any leading axes are kept as-is; only the trailing two axes are resized.
    """
    tensor = torch.as_tensor(arr)
    leading_shape = tuple(tensor.shape[:-2])
    # interpolate() treats a 4-D input as (N, C, H, W) and resizes both trailing
    # axes; a 3-D input would be read as (N, C, L) and only the last axis would
    # shrink. Fold every leading axis into C so H and W are both resized.
    as_4d = tensor.reshape(1, -1, *tensor.shape[-2:])
    resized = torch.nn.functional.interpolate(
        as_4d, scale_factor=float(scale_factor), mode="nearest"
    )
    return resized.reshape(*leading_shape, *resized.shape[-2:]).numpy()


def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Fit ``gt ~= scale * pred + shift`` over valid pixels and apply it to ``pred_arr``.

    Args:
        gt_arr: Reference values. Singleton axes are squeezed away before fitting.
        pred_arr: Prediction to align. The aligned result keeps this array's shape.
        valid_mask_arr: Pixels that take part in the fit. Cast to bool, so 0/1
            integer masks behave as boolean masks.
        return_scale_shift: If True, also return the fitted scale and shift.
        max_resolution: Optional cap on the spatial size used for the fit. When the
            last two axes of ``pred_arr`` exceed it, gt/pred/mask are
            nearest-neighbour downscaled by ``min(max_resolution / (H, W))``
            before fitting. The fitted transform is still applied to the
            full-resolution ``pred_arr``.

    Returns:
        ``aligned_pred``, or ``(aligned_pred, scale, shift)`` when
        ``return_scale_shift`` is True. ``scale`` and ``shift`` are 1-element
        arrays.

    Raises:
        AssertionError: if gt, pred and mask do not share a shape after squeezing.
    """
    ori_shape = pred_arr.shape  # the aligned output is reshaped back to this

    gt = gt_arr.squeeze()
    pred = pred_arr.squeeze()
    # Cast so that integer masks select pixels instead of acting as fancy indices.
    valid_mask = valid_mask_arr.squeeze().astype(bool)

    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            gt = _downscale_nearest(gt, scale_factor)
            pred = _downscale_nearest(pred, scale_factor)
            # Interpolation needs a numeric dtype; nearest keeps values in {0, 1}.
            valid_mask = _downscale_nearest(
                valid_mask.astype(np.float32), scale_factor
            ).astype(bool)

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    gt_masked = gt[valid_mask].reshape((-1, 1))
    pred_masked = pred[valid_mask].reshape((-1, 1))

    # Solve [pred, 1] @ [scale, shift]^T ~= gt in the least-squares sense.
    A = np.concatenate([pred_masked, np.ones_like(pred_masked)], axis=-1)
    X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]
    scale, shift = X  # X has shape (2, 1), so each is a 1-element array

    # Apply at full resolution, even if the fit used downscaled copies.
    aligned_pred = (pred_arr * scale + shift).reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    return aligned_pred
|
|
|
|
|
|
|
def depth2disparity(depth, return_mask=False):
    """Take the reciprocal of strictly positive entries; all other entries become 0.

    Args:
        depth: A ``torch.Tensor`` or ``np.ndarray``.
        return_mask: If True, also return the boolean mask of entries that were > 0.

    Returns:
        ``disparity``, or ``(disparity, mask)`` when ``return_mask`` is True.
        Floating inputs keep their dtype. Non-floating inputs produce a float
        result (torch default float dtype / ``np.float64``), so reciprocals are
        not truncated.

    Raises:
        TypeError: if ``depth`` is neither a ``torch.Tensor`` nor an ``np.ndarray``.
    """
    if isinstance(depth, torch.Tensor):
        # An integer output buffer would reject (torch) the float reciprocals.
        out_dtype = (
            depth.dtype if depth.is_floating_point() else torch.get_default_dtype()
        )
        disparity = torch.zeros_like(depth, dtype=out_dtype)
    elif isinstance(depth, np.ndarray):
        # An integer output buffer would silently truncate 1/d to 0 (numpy).
        out_dtype = (
            depth.dtype if np.issubdtype(depth.dtype, np.floating) else np.float64
        )
        disparity = np.zeros_like(depth, dtype=out_dtype)
    else:
        raise TypeError(
            f"depth must be a torch.Tensor or np.ndarray, got {type(depth).__name__}"
        )

    # Zero and negative entries are left at 0 instead of producing inf/negative values.
    positive_mask = depth > 0
    disparity[positive_mask] = 1.0 / depth[positive_mask]

    if return_mask:
        return disparity, positive_mask
    return disparity
|
|
|
def disparity2depth(disparity, **kwargs):
    """Convert disparity back to depth by delegating to ``depth2disparity``.

    The positive-entry reciprocal (non-positive entries mapped to 0) is its own
    inverse on positive values, so the same transform serves both directions.
    Keyword arguments (e.g. ``return_mask``) are forwarded unchanged.
    """
    converted = depth2disparity(disparity, **kwargs)
    return converted
|
|