Spaces:
Paused
Paused
import torch | |
from torch.autograd import Function | |
from ..utils import ext_loader | |
ext_module = ext_loader.load_ext('_ext', ['knn_forward']) | |
class KNN(Function): | |
r"""KNN (CUDA) based on heap data structure. | |
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/ | |
scene_seg/lib/pointops/src/knnquery_heap>`_. | |
Find k-nearest points. | |
""" | |
def forward(ctx, | |
k: int, | |
xyz: torch.Tensor, | |
center_xyz: torch.Tensor = None, | |
transposed: bool = False) -> torch.Tensor: | |
""" | |
Args: | |
k (int): number of nearest neighbors. | |
xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N). | |
xyz coordinates of the features. | |
center_xyz (Tensor, optional): (B, npoint, 3) if transposed == | |
False, else (B, 3, npoint). centers of the knn query. | |
Default: None. | |
transposed (bool, optional): whether the input tensors are | |
transposed. Should not explicitly use this keyword when | |
calling knn (=KNN.apply), just add the fourth param. | |
Default: False. | |
Returns: | |
Tensor: (B, k, npoint) tensor with the indices of | |
the features that form k-nearest neighbours. | |
""" | |
assert (k > 0) & (k < 100), 'k should be in range(0, 100)' | |
if center_xyz is None: | |
center_xyz = xyz | |
if transposed: | |
xyz = xyz.transpose(2, 1).contiguous() | |
center_xyz = center_xyz.transpose(2, 1).contiguous() | |
assert xyz.is_contiguous() # [B, N, 3] | |
assert center_xyz.is_contiguous() # [B, npoint, 3] | |
center_xyz_device = center_xyz.get_device() | |
assert center_xyz_device == xyz.get_device(), \ | |
'center_xyz and xyz should be put on the same device' | |
if torch.cuda.current_device() != center_xyz_device: | |
torch.cuda.set_device(center_xyz_device) | |
B, npoint, _ = center_xyz.shape | |
N = xyz.shape[1] | |
idx = center_xyz.new_zeros((B, npoint, k)).int() | |
dist2 = center_xyz.new_zeros((B, npoint, k)).float() | |
ext_module.knn_forward( | |
xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k) | |
# idx shape to [B, k, npoint] | |
idx = idx.transpose(2, 1).contiguous() | |
if torch.__version__ != 'parrots': | |
ctx.mark_non_differentiable(idx) | |
return idx | |
def backward(ctx, a=None): | |
return None, None, None | |
knn = KNN.apply | |