import torch
import numpy as np
from scipy.spatial import KDTree
# [docs]
def subsample_points_by_radius(x: torch.Tensor,
                               radius: float) -> torch.LongTensor:
    """
    Subsample points by radius. This is used in simulating mean curvature flow.

    Points that have no other point within ``radius`` are always kept. The
    remaining points are thinned greedily: a random candidate is kept, every
    point within ``radius`` of it (exact duplicates included) is discarded,
    and this repeats until no candidates are left. The random picks use
    NumPy's global random state, so seed it with ``np.random.seed`` when
    reproducible output is needed.

    Args:
        x (torch.Tensor): Points to subsample, one point per row.
        radius (float): Radius for subsampling. Must be non-negative.

    Returns:
        torch.LongTensor: Indices into ``x`` of the subsampled points.
            Isolated points come first in increasing index order, followed
            by the randomly selected points in the order they were picked.

    Raises:
        ValueError: If ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")

    points = x.detach().cpu().numpy()
    n_points = points.shape[0]
    if n_points == 0:
        return torch.empty(0, dtype=torch.long)

    tree = KDTree(points)
    # Each neighbour list contains the query point itself, so a length of 1
    # means that no *other* point lies within the radius.
    neighbors = tree.query_ball_point(points, r=radius)
    counts = np.array([len(nb) for nb in neighbors])
    isolated = np.flatnonzero(counts == 1)

    # Greedy random thinning of the crowded points. Exclusion is driven by
    # the neighbour lists rather than by non-zero distance values, so points
    # at distance exactly 0 (duplicates) are correctly treated as neighbours.
    available = counts > 1
    picked = []
    candidates = np.flatnonzero(available)
    while candidates.size:
        idx = np.random.choice(candidates)
        picked.append(idx)
        available[idx] = False
        available[neighbors[idx]] = False
        candidates = np.flatnonzero(available)

    # Build the result as int64 explicitly so it is a valid index tensor even
    # when one of the two parts is empty.
    keepers = np.concatenate((isolated, np.asarray(picked, dtype=np.int64)))
    return torch.from_numpy(keepers.astype(np.int64))
# [docs]
def smooth_values_by_gaussian(x: torch.Tensor,
                              values: torch.Tensor,
                              radius: float) -> torch.Tensor:
    """
    Smooth values by a truncated Gaussian kernel over a given radius. This is used
    in simulating mean curvature flow.

    Each output is the weighted average of the values at the points lying
    within ``radius`` of the corresponding input point (the point itself
    included), with normalized Gaussian weights.

    Args:
        x (torch.Tensor): Input data points, one point per row.
        values (torch.Tensor): Values to be smoothed, indexed along the first
            dimension by point. May be 1-D (one scalar per point) or carry
            extra trailing dimensions (e.g. one row of features per point).
        radius (float): Radius for truncating the Gaussian kernel. The standard deviation
            of the Gaussian is set to one third of the radius. Must be positive.

    Returns:
        torch.Tensor: Smoothed values. For 1-D ``values`` the result has
            shape ``(n, 1)``; otherwise it has the same shape as ``values``.

    Raises:
        ValueError: If ``radius`` is not positive.
    """
    if radius <= 0:
        # A zero radius would give a zero-variance kernel (0 / 0 -> NaN).
        raise ValueError("radius must be positive")

    device = x.device
    points = x.detach().cpu().numpy()
    n_points = points.shape[0]
    tree = KDTree(points)

    # The largest neighbourhood size determines how many nearest neighbours
    # must be queried so that no point within the radius is missed.
    k = max(len(nb) for nb in tree.query_ball_point(points, r=radius))
    dist, ind = tree.query(points, k=k)
    # KDTree.query drops the neighbour axis when k == 1; restore it so the
    # per-row reductions below work when every point is isolated.
    dist = torch.from_numpy(np.reshape(dist, (n_points, k))).float().to(device)
    ind = torch.from_numpy(np.reshape(ind, (n_points, k))).long().to(device)

    sigma = radius / 3
    weights = torch.exp(-dist.pow(2) / (2 * sigma ** 2))
    # Rows with fewer than k neighbours inside the radius were padded with
    # farther points by the k-NN query; truncate the kernel to exclude them.
    weights[dist > radius] = 0
    weights = weights / weights.sum(dim=1, keepdim=True)

    gathered = values[ind]  # (n, k, *values.shape[1:])
    if gathered.dim() > 2:
        # Broadcast the (n, k) weights over any trailing value dimensions.
        weights = weights.reshape(weights.shape + (1,) * (gathered.dim() - 2))
        return (weights * gathered).sum(dim=1)
    return (weights * gathered).sum(dim=1, keepdim=True)