|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
The ray sampler is a module that takes in camera matrices and resolution and batches of rays. |
|
Expects cam2world matrices that use the OpenCV camera coordinate system conventions. |
|
""" |
|
|
|
import torch |
|
from pdb import set_trace as st |
|
import random |
|
|
|
HUGE_NUMBER = 1e10 |
|
TINY_NUMBER = 1e-6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
def depth2pts_outside(ray_o, ray_d, depth): |
|
''' |
|
ray_o, ray_d: [..., 3] |
|
depth: [...]; inverse of distance to sphere origin |
|
''' |
|
|
|
d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) |
|
p_mid = ray_o + d1.unsqueeze(-1) * ray_d |
|
p_mid_norm = torch.norm(p_mid, dim=-1) |
|
ray_d_cos = 1. / torch.norm(ray_d, dim=-1) |
|
d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos |
|
p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d |
|
|
|
rot_axis = torch.cross(ray_o, p_sphere, dim=-1) |
|
rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True) |
|
phi = torch.asin(p_mid_norm) |
|
theta = torch.asin(p_mid_norm * depth) |
|
rot_angle = (phi - theta).unsqueeze(-1) |
|
|
|
|
|
|
|
p_sphere_new = p_sphere * torch.cos(rot_angle) + \ |
|
torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \ |
|
rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle)) |
|
p_sphere_new = p_sphere_new / torch.norm( |
|
p_sphere_new, dim=-1, keepdim=True) |
|
pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1) |
|
|
|
|
|
depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1 |
|
return pts, depth_real |
|
|
|
|
|
class RaySampler(torch.nn.Module): |
|
|
|
def __init__(self): |
|
super().__init__() |
|
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None |
|
|
|
def create_patch_uv(self, |
|
patch_resolution, |
|
resolution, |
|
cam2world_matrix, |
|
fg_bbox=None): |
|
|
|
def sample_patch_uv(fg_bbox=None): |
|
assert patch_resolution <= resolution |
|
|
|
def sample_patch_range(): |
|
patch_reolution_start = random.randint( |
|
0, resolution - |
|
patch_resolution) |
|
|
|
return patch_reolution_start |
|
|
|
def sample_patch_range_oversample_boundary(range_start=None, |
|
range_end=None): |
|
|
|
if range_start is None: |
|
|
|
range_start = patch_resolution |
|
if range_end is None: |
|
|
|
range_end = resolution + patch_resolution |
|
|
|
|
|
patch_reolution_end = random.randint( |
|
range_start, |
|
range_end, |
|
) |
|
|
|
|
|
if patch_reolution_end <= patch_resolution: |
|
patch_reolution_end = patch_resolution |
|
elif patch_reolution_end > resolution: |
|
patch_reolution_end = resolution |
|
|
|
|
|
return patch_reolution_end |
|
|
|
|
|
|
|
if fg_bbox is not None and random.random( |
|
) > 0.125: |
|
|
|
|
|
top_min, left_min = fg_bbox[:, :2].min(dim=0, |
|
keepdim=True)[0][0] |
|
height_max, width_max = fg_bbox[:, 2:].max(dim=0, |
|
keepdim=True)[0][0] |
|
|
|
if top_min + patch_resolution < height_max: |
|
h_end = sample_patch_range_oversample_boundary( |
|
top_min + patch_resolution, height_max) |
|
else: |
|
h_end = max( |
|
height_max.to(torch.uint8).item(), patch_resolution) |
|
if left_min + patch_resolution < width_max: |
|
w_end = sample_patch_range_oversample_boundary( |
|
left_min + patch_resolution, width_max) |
|
else: |
|
w_end = max( |
|
width_max.to(torch.uint8).item(), patch_resolution) |
|
|
|
h_start = h_end - patch_resolution |
|
w_start = w_end - patch_resolution |
|
|
|
try: |
|
assert h_start >= 0 and w_start >= 0 |
|
except: |
|
st() |
|
|
|
else: |
|
h_end = sample_patch_range_oversample_boundary() |
|
h_start = h_end - patch_resolution |
|
w_end = sample_patch_range_oversample_boundary() |
|
w_start = w_end - patch_resolution |
|
|
|
assert h_start >= 0 and w_start >= 0 |
|
|
|
uv = torch.stack( |
|
torch.meshgrid( |
|
torch.arange( |
|
start=h_start, |
|
|
|
end=h_end, |
|
dtype=torch.float32, |
|
device=cam2world_matrix.device), |
|
torch.arange( |
|
start=w_start, |
|
|
|
end=w_end, |
|
dtype=torch.float32, |
|
device=cam2world_matrix.device), |
|
indexing='ij')) * (1. / resolution) + (0.5 / resolution) |
|
|
|
uv = uv.flip(0).reshape(2, -1).transpose(1, 0) |
|
|
|
return uv, (h_start, w_start, patch_resolution, patch_resolution |
|
) |
|
|
|
all_uv = [] |
|
ray_bboxes = [] |
|
for _ in range(cam2world_matrix.shape[0]): |
|
uv, bbox = sample_patch_uv(fg_bbox) |
|
all_uv.append(uv) |
|
ray_bboxes.append(bbox) |
|
|
|
all_uv = torch.stack(all_uv, 0) |
|
|
|
|
|
return all_uv, ray_bboxes |
|
|
|
def create_uv(self, resolution, cam2world_matrix): |
|
|
|
uv = torch.stack( |
|
torch.meshgrid(torch.arange(resolution, |
|
dtype=torch.float32, |
|
device=cam2world_matrix.device), |
|
torch.arange(resolution, |
|
dtype=torch.float32, |
|
device=cam2world_matrix.device), |
|
indexing='ij')) * (1. / resolution) + (0.5 / |
|
resolution) |
|
|
|
uv = uv.flip(0).reshape(2, -1).transpose(1, 0) |
|
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) |
|
|
|
return uv |
|
|
|
def forward(self, cam2world_matrix, intrinsics, resolution, fg_mask=None): |
|
""" |
|
Create batches of rays and return origins and directions. |
|
|
|
cam2world_matrix: (N, 4, 4) |
|
intrinsics: (N, 3, 3) |
|
resolution: int |
|
|
|
ray_origins: (N, M, 3) |
|
ray_dirs: (N, M, 2) |
|
""" |
|
N, M = cam2world_matrix.shape[0], resolution**2 |
|
cam_locs_world = cam2world_matrix[:, :3, 3] |
|
fx = intrinsics[:, 0, 0] |
|
fy = intrinsics[:, 1, 1] |
|
cx = intrinsics[:, 0, 2] |
|
cy = intrinsics[:, 1, 2] |
|
sk = intrinsics[:, 0, 1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv = self.create_uv( |
|
resolution, |
|
cam2world_matrix, |
|
) |
|
|
|
x_cam = uv[:, :, 0].view(N, -1) |
|
y_cam = uv[:, :, 1].view(N, -1) |
|
z_cam = torch.ones((N, M), device=cam2world_matrix.device) |
|
|
|
|
|
x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1) * |
|
sk.unsqueeze(-1) / fy.unsqueeze(-1) - sk.unsqueeze(-1) * |
|
y_cam / fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam |
|
y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam |
|
|
|
cam_rel_points = torch.stack( |
|
(x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) |
|
|
|
|
|
|
|
world_rel_points = torch.bmm(cam2world_matrix, |
|
cam_rel_points.permute(0, 2, 1)).permute( |
|
0, 2, 1)[:, :, :3] |
|
|
|
ray_dirs = world_rel_points - cam_locs_world[:, None, :] |
|
ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) |
|
|
|
ray_origins = cam_locs_world.unsqueeze(1).repeat( |
|
1, ray_dirs.shape[1], 1) |
|
|
|
return ray_origins, ray_dirs, None |
|
|
|
|
|
class PatchRaySampler(RaySampler): |
|
|
|
def forward(self, |
|
cam2world_matrix, |
|
intrinsics, |
|
patch_resolution, |
|
resolution, |
|
fg_bbox=None): |
|
""" |
|
Create batches of rays and return origins and directions. |
|
|
|
cam2world_matrix: (N, 4, 4) |
|
intrinsics: (N, 3, 3) |
|
resolution: int |
|
|
|
ray_origins: (N, M, 3) |
|
ray_dirs: (N, M, 2) |
|
""" |
|
N, M = cam2world_matrix.shape[0], patch_resolution**2 |
|
cam_locs_world = cam2world_matrix[:, :3, 3] |
|
fx = intrinsics[:, 0, 0] |
|
fy = intrinsics[:, 1, 1] |
|
cx = intrinsics[:, 0, 2] |
|
cy = intrinsics[:, 1, 2] |
|
sk = intrinsics[:, 0, 1] |
|
|
|
|
|
|
|
|
|
all_uv_list = [] |
|
ray_bboxes = [] |
|
for idx in range(N): |
|
uv, bboxes = self.create_patch_uv( |
|
patch_resolution, resolution, cam2world_matrix[idx:idx + 1], |
|
fg_bbox[idx:idx + 1] |
|
if fg_bbox is not None else None) |
|
all_uv_list.append( |
|
uv |
|
|
|
) |
|
ray_bboxes.extend(bboxes) |
|
all_uv = torch.cat(all_uv_list, 0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
x_cam = all_uv[:, :, 0].view(N, -1) |
|
y_cam = all_uv[:, :, 1].view(N, -1) |
|
z_cam = torch.ones((N, M), device=cam2world_matrix.device) |
|
|
|
|
|
x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1) * |
|
sk.unsqueeze(-1) / fy.unsqueeze(-1) - sk.unsqueeze(-1) * |
|
y_cam / fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam |
|
y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam |
|
|
|
cam_rel_points = torch.stack( |
|
(x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) |
|
|
|
world_rel_points = torch.bmm(cam2world_matrix, |
|
cam_rel_points.permute(0, 2, 1)).permute( |
|
0, 2, 1)[:, :, :3] |
|
|
|
ray_dirs = world_rel_points - cam_locs_world[:, None, :] |
|
ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) |
|
|
|
ray_origins = cam_locs_world.unsqueeze(1).repeat( |
|
1, ray_dirs.shape[1], 1) |
|
|
|
return ray_origins, ray_dirs, ray_bboxes |
|
|