Source code for gnp.dataset.patch

import numpy as np
import torch
from typing import Optional
import torch_geometric as tg
from scipy.spatial import KDTree
from torch_geometric.data import Batch, Data
from torch_geometric.nn import radius_graph

from .pca import PCABatch

def graph_from_dict(sample: dict, radius: float, max_num_neighbors: int=32):
    """
    Build a radius graph from a dictionary of per-node tensors.

    Every node is connected to the nodes lying within `radius` of it
    (the node itself included). When a node has more than
    `max_num_neighbors` such neighbors, a random subset of that size is
    drawn without replacement using NumPy's global random state.

    Parameters
    ----------
    sample : dict
        Dictionary of node data. ``sample['x']`` holds the node
        coordinates used for the neighbor search. ``sample['edge_data']``,
        when present and not None, supplies the per-node features from
        which edge attributes are built; otherwise ``x`` is used.
        Every other entry is attached to the returned object unchanged.
    radius : float
        Radius for the radius graph.
    max_num_neighbors : int, optional
        Maximum number of neighbors per node. Defaults to 32.

    Returns
    -------
    torch_geometric.data.Data
        Graph with ``x``, ``edge_index`` and ``edge_attr``, moved to the
        device ``x`` was on. ``edge_attr`` has ``2 * edge_data.shape[1]``
        columns: the features of the two endpoint nodes of each edge
        (assumes ``edge_data`` is 2-D, one row per node).
    """
    x = sample.get('x')
    x_device = x.device

    edge_data = sample.get('edge_data')
    if edge_data is None:
        edge_data = x

    # The KD-tree lives on the CPU, so work on a NumPy copy of the points.
    points = x.cpu().detach().numpy()
    neighborhoods = KDTree(points).query_ball_point(points, r=radius)

    rows, cols = [], []
    for node, neighbors in enumerate(neighborhoods):
        # Cap the neighborhood size by random subsampling.
        keep = min(max_num_neighbors, len(neighbors))
        chosen = np.random.choice(neighbors, size=keep, replace=False)
        col = torch.tensor(chosen).long()
        cols.append(col)
        rows.append(torch.full_like(col, node))

    edge_ind = torch.stack((torch.cat(rows), torch.cat(cols)), dim=0)

    # Gather the features of both endpoints and flatten them per edge.
    endpoint_features = edge_data[edge_ind.t()]
    edge_attr = endpoint_features.reshape(-1, 2 * edge_data.shape[1])

    data = Data(x=x, edge_index=edge_ind, edge_attr=edge_attr)
    for key, value in sample.items():
        if key in ('x', 'edge_data'):
            continue
        setattr(data, key, value)

    return data.to(x_device)
class PatchGenerator:
    """
    Iterator that generates batches of patches from a point cloud.

    The generator is single-pass: ``__iter__`` returns the object itself
    and the internal position is never reset, so a second loop over the
    same instance yields nothing.

    Parameters
    ----------
    data : dict
        Point cloud data. Must include keys 'x' for point cloud and
        'normals' which must have the correct orientation of the
        pointcloud. The normals do not need to be accurate other than the
        orientation. A 'path' entry, if present, is dropped; every other
        entry is indexed per patch and must therefore be a tensor.
    graph_radius : float
        Radius for graph construction.
    batch_size : int, optional
        Batch size. Defaults to 1.
    center : str, optional
        Method for choosing patch centers. If 'tree', a KDTree is used to
        ensure patches completely cover the point cloud. If 'gmls', there
        is one patch for every point in the data. Defaults to 'tree'.
    shuffle_patches : bool, optional
        Whether to shuffle patches. Defaults to True.
    knn : int, optional
        Number of neighbors to choose for patch construction.
        Defaults to 50.
    min_radius : float, optional
        Minimum patch radius. Defaults to 0.01.
    pca : bool, optional
        If True, returns patch in local PCA coordinates. Defaults to True.
    degree : int, optional
        Maximum degree of 1D Legendre basis if pca is True. Defaults to 3.
    orientation : Optional[torch.Tensor], optional
        Accepted for interface compatibility; it is not used by this
        class. Defaults to None.
    min_z_scale : float, optional
        Minimum scale used in pca rescaling. Defaults to 1e-3.
    max_num_neighbors : int, optional
        Maximum neighbors used in graph construction. Defaults to 32.
    device : str, optional
        Device to use. Defaults to 'cpu'.

    Raises
    ------
    ValueError
        If `center` is neither 'tree' nor 'gmls'.
    """

    def __init__(self, data: dict, graph_radius: float, batch_size: int=1,
                 center: str='tree', shuffle_patches: bool=True,
                 knn: int=50, min_radius: float=0.01, pca: bool=True,
                 degree: int=3, orientation: Optional[torch.Tensor]=None,
                 min_z_scale: float=1e-3, max_num_neighbors: int=32,
                 device: str ='cpu'):
        self.current_idx = 0
        # Shallow copy so dropping 'path' does not touch the caller's dict.
        self.data = data.copy()
        self.x = self.data.get('x').squeeze(0)
        self.graph_radius = graph_radius
        self.batch_size = batch_size
        self.shuffle = shuffle_patches
        self.device = device
        self.tree = KDTree(self.x.cpu().numpy())
        self.center = center
        self.knn = knn
        self.min_radius = torch.tensor(min_radius, device=self.device)
        self.centers = self.get_centers()
        self.max_num_neighbors = max_num_neighbors

        # 'path' is not per-point data and cannot be indexed in get_patch.
        self.data.pop('path', None)

        if self.shuffle:
            self.order = torch.randperm(self.centers.shape[0])
        else:
            self.order = torch.arange(self.centers.shape[0])

        self.center_data = self.x[self.centers]
        self.num_patches = self.centers.shape[0]

        # Recompute the neighborhoods of the centers with an approximate
        # search; this overwrites the values stored by get_centers.
        self._update_knn(self.center_data.cpu().numpy(), eps=0.05)

        # Fallback for a get_centers override that did not assign clusters.
        if not hasattr(self, 'clusters'):
            self.clusters = tg.nn.knn(self.center_data, self.x, k=1)[1].squeeze()

        self.patch_idx, self.buffer_idx = self.get_patch_indices()
        self.pca = pca
        self.degree = degree
        self.min_z_scale = min_z_scale

    def _update_knn(self, points: np.ndarray, eps: float=0.0) -> None:
        """Store the `knn` nearest neighbors of `points` in `knn_ind` and
        the distance to the farthest of them in `knn_dist`."""
        k = int(self.knn)
        knn_dist, knn_ind = self.tree.query(points, k=k, eps=eps)
        # scipy drops the neighbor axis when k == 1; restore it so the
        # column indexing below (and in get_centers) always works.
        knn_dist = np.asarray(knn_dist).reshape(points.shape[0], k)
        self.knn_ind = np.asarray(knn_ind).reshape(points.shape[0], k)
        self.knn_dist = torch.from_numpy(knn_dist[:, -1]).to(self.device)

    def get_centers(self) -> torch.LongTensor:
        """
        Choose the centers of the patches.

        If `self.center` is 'tree', the centers are chosen so that there
        is sufficient coverage of the data: points are drawn at random
        (NumPy global random state) from those not yet covered, and each
        draw covers its ``knn // 3`` nearest neighbors. If `self.center`
        is 'gmls', there is one patch center per point.

        As a side effect this sets `self.knn_ind`, `self.knn_dist` and
        `self.clusters`.

        Returns
        -------
        torch.LongTensor
            Indices of the patch centers.

        Raises
        ------
        ValueError
            If `self.center` is neither 'tree' nor 'gmls'.
        """
        points = self.x.cpu().numpy()
        num_points = self.x.shape[0]

        if self.center == 'tree':
            # At least one neighbor is requested so the query is valid
            # for small knn.
            _, ind = self.tree.query(points, k=max(1, int(self.knn // 3)))
            centers = []
            mask = np.ones(num_points, dtype=bool)
            arange = np.arange(num_points)
            while mask.any():
                idx = np.random.choice(arange[mask])
                centers.append(idx)
                mask[ind[idx]] = False
                # Always cover the drawn point itself; with duplicated
                # points it may be missing from its own neighbor list,
                # which would otherwise keep the loop from terminating.
                mask[idx] = False
            centers = torch.LongTensor(centers)

            self._update_knn(self.x.cpu()[centers].numpy())

            # Assign each point to a center whose neighborhood contains
            # it. Columns are written from farthest to nearest so that the
            # nearest-rank assignment is the one that survives. Points in
            # no neighborhood keep cluster 0.
            clusters = torch.zeros(num_points, dtype=torch.long)
            arange = torch.arange(self.knn_ind.shape[0])
            for j in range(self.knn_ind.shape[1] - 1, -1, -1):
                clusters[self.knn_ind[:, j]] = arange
            self.clusters = clusters.to(self.device)
        elif self.center == 'gmls':
            centers = torch.arange(num_points)
            self._update_knn(points)
            self.clusters = centers.clone().to(self.device)
        else:
            raise ValueError(
                f"Unknown center method {self.center!r}; "
                "expected 'tree' or 'gmls'."
            )

        return centers

    def get_patch_indices(self) -> tuple[list[int], list[int]]:
        """
        Returns indices of the patches and buffer regions.

        The patch radius of a center is 1.1 times the distance to its
        farthest stored neighbor, but at least `self.min_radius`; the
        buffer radius is 1.25 times the patch radius.

        Returns
        -------
        tuple of list of int
            patch_idx : list of int
                List of indices of the patches.
            buffer_idx : list of int
                List of indices of the buffer regions.
        """
        radii = torch.maximum(1.1 * self.knn_dist, self.min_radius).cpu().numpy()
        center_points = self.center_data.cpu().numpy()
        patch_idx = self.tree.query_ball_point(center_points, r=radii,
                                               p=2, eps=0.01)
        buffer_idx = self.tree.query_ball_point(center_points, r=1.25 * radii,
                                                p=2, eps=0.01)
        return patch_idx, buffer_idx

    def get_patch(self, idx: int) -> "Data":
        """
        Returns a single patch in the form of a torch_geometric Data
        object.

        Parameters
        ----------
        idx : int
            Index of the patch to return.

        Returns
        -------
        Data
            torch_geometric Data object containing the patch and all the
            associated data. In addition to the graph it carries ``mask``
            (True for points inside the patch radius, False for points
            that are only in the buffer), ``cluster_mask`` (True for
            points assigned to the same cluster as the patch center) and
            ``ind`` (indices of the patch points in the full point cloud).
        """
        idx = int(idx)
        eval_idx = np.asarray(self.patch_idx[idx], dtype=np.int64)
        buffer_idx = np.asarray(self.buffer_idx[idx], dtype=np.int64)
        # Buffer-only points, in their original order. Done in NumPy to
        # avoid an element-by-element tensor membership test and to keep
        # an integer dtype even when there are no such points.
        extra_idx = buffer_idx[~np.isin(buffer_idx, eval_idx)]

        eval_patch_idx = torch.from_numpy(eval_idx).to(self.device)
        extra_patch_idx = torch.from_numpy(extra_idx).to(self.device)
        patch_idx = torch.cat((eval_patch_idx, extra_patch_idx)).long()

        cluster_mask = (
            self.clusters[patch_idx] == self.clusters[self.centers[idx]]
        ).view(-1, 1)
        mask = torch.cat(
            (torch.ones_like(eval_patch_idx), torch.zeros_like(extra_patch_idx)),
            dim=0,
        ).bool().view(-1, 1)

        patch_data = {}
        for k, v in self.data.items():
            v = v.squeeze(0)
            # Index on the tensor's own device so data that lives on a
            # different device than `self.device` can still be sliced.
            patch_data[k] = v[patch_idx.to(v.device)]

        data = graph_from_dict(patch_data, radius=self.graph_radius,
                               max_num_neighbors=self.max_num_neighbors)
        data.mask = mask
        data.cluster_mask = cluster_mask
        data.ind = patch_idx
        return data.to(self.device)

    def __iter__(self):
        # Single-pass iterator: the position is not reset here.
        return self

    def __next__(self):
        num_patches = self.order.shape[0]
        if self.current_idx >= num_patches:
            raise StopIteration

        start = self.current_idx
        # The last batch may be smaller than batch_size.
        count = min(num_patches - start, self.batch_size)
        self.current_idx += count

        batch = Batch.from_data_list(
            [self.get_patch(self.order[start + i]) for i in range(count)]
        )
        if not self.pca:
            return batch

        pca_mapper = PCABatch(num_graphs=batch.num_graphs,
                              degree=self.degree,
                              min_z_scale=self.min_z_scale)
        return pca_mapper.to_pca(batch)
class PatchLoader:
    """
    Factory that stores patch settings and builds a PatchGenerator for
    each point cloud it is called with.

    Parameters
    ----------
    graph_radius : float
        Radius for graph construction.
    batch_size : int, optional
        Batch size. Defaults to 1.
    center : str, optional
        Method for choosing patch centers. If 'tree', a KDTree is used to
        ensure patches completely cover the point cloud. If 'gmls', there
        is one patch for every point in the data. Defaults to 'tree'.
    shuffle_patches : bool, optional
        Whether to shuffle patches. Defaults to True.
    knn : int, optional
        Number of neighbors to choose for patch construction.
        Defaults to 50.
    min_radius : float, optional
        Minimum patch radius. Defaults to 0.01.
    pca : bool, optional
        If True, returns patch in local PCA coordinates. Defaults to True.
    degree : int, optional
        Maximum degree of 1D Legendre basis if pca is True. Defaults to 3.
    min_z_scale : float, optional
        Minimum scale used in pca rescaling. Defaults to 1e-3.
    max_num_neighbors : int, optional
        Maximum neighbors used in graph construction. Defaults to 32.
    device : str, optional
        Device to use. Defaults to 'cpu'.
    """

    def __init__(self, graph_radius: float, batch_size: int=1,
                 center: str='tree', shuffle_patches: bool=True,
                 knn: int=50, min_radius: float=0.01, pca: bool=True,
                 degree: int=3, min_z_scale: float=1e-3,
                 max_num_neighbors: int=32, device: str ='cpu'):
        # Graph construction settings.
        self.graph_radius = graph_radius
        self.max_num_neighbors = max_num_neighbors

        # Patch selection settings.
        self.center = center
        self.knn = knn
        self.min_radius = min_radius

        # PCA mapping settings.
        self.pca = pca
        self.degree = degree
        self.min_z_scale = min_z_scale

        # Iteration settings.
        self.batch_size = batch_size
        self.shuffle = shuffle_patches
        self.device = device

    def __call__(self, data: dict) -> PatchGenerator:
        """
        Returns a PatchGenerator object.

        Parameters
        ----------
        data : dict
            Point cloud data. Must include keys 'x' for point cloud and
            'normals' which must have the correct orientation of the
            pointcloud. The normals do not need to be accurate other than
            the orientation.

        Returns
        -------
        PatchGenerator
            PatchGenerator object configured with the stored settings.
        """
        settings = {
            'graph_radius': self.graph_radius,
            'batch_size': self.batch_size,
            'shuffle_patches': self.shuffle,
            'center': self.center,
            'knn': self.knn,
            'min_radius': self.min_radius,
            'pca': self.pca,
            'degree': self.degree,
            'min_z_scale': self.min_z_scale,
            'max_num_neighbors': self.max_num_neighbors,
            'device': self.device,
        }
        return PatchGenerator(data=data, **settings)