"""Source code for gnp.estimator."""

import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from typing import Optional
from torch_geometric.data import Batch
from torch_geometric.nn import fps
from tqdm import tqdm
import scipy.sparse as sp

from .config import load_config, load_model, load_patchloader
from .geometry.surface import SurfacePatch
from .utils import smooth_values_by_gaussian, subsample_points_by_radius 

# Directory holding the pretrained model weights, resolved relative to this file.
MODEL_PATH = Path(__file__).parent / 'model_weights'

class GeometryEstimator:
    """
    Geometry estimator for a point cloud.

    Parameters
    ----------
    pcd : torch.Tensor
        Point cloud data.
    orientation : torch.Tensor
        Orientation data (per-point normals).
    function_values : Optional[torch.Tensor]
        Function values for the point cloud, stored under
        ``data['function_values']`` when provided.
    model : str
        Model name; one of ``'clean_30k'``, ``'clean_50k'``, ``'noise_70k'``,
        ``'outlier_50k'``.
    device : torch.device
        Device to run the model on.

    Attributes
    ----------
    data : dict
        Dictionary containing the input data fed to the patch loader.
    model_path : Path
        Path to the model state dictionary.
    model : nn.Module
        Loaded model.
    config_path : Path
        Path to the configuration file.
    cfg : dict
        Configuration dictionary.
    patch_loader : PatchLoader
        Patch loader object.
    """

    def __init__(self, pcd: torch.Tensor, orientation: torch.Tensor,
                 function_values: Optional[torch.Tensor] = None,
                 model: str = 'clean_30k', device: torch.device = 'cpu',
                 **patch_kwargs: Optional[dict]):
        # NOTE(review): assert is stripped under `python -O`; a ValueError
        # would validate `model` unconditionally.
        assert model in ['clean_30k', 'clean_50k', 'noise_70k', 'outlier_50k']
        self.pcd = pcd.to(device)
        self.orientation = orientation.to(device)
        self.device = device
        self.data = {'x': self.pcd, 'normals': self.orientation}
        # Load the pretrained weights and configuration shipped under
        # MODEL_PATH/<model>/.
        self.model_path = Path.joinpath(MODEL_PATH, model, 'state_dict.pth')
        self.model = load_model(self.model_path)
        self.model.to(device)
        self.config_path = Path.joinpath(MODEL_PATH, model, 'config.yaml')
        self.cfg = load_config(self.config_path)
        self.cfg['device'] = device
        self.patch_loader = load_patchloader(self.cfg, **patch_kwargs)
        if function_values is not None:
            self.function_values = function_values.to(device)
            self.data['function_values'] = self.function_values
[docs] def surface_patch(self, patch_data: Batch) -> SurfacePatch: """ Create a SurfacePatch from the input patch data. Parameters ---------- patch_data : Batch Batch containing the input patch data. Returns ------- SurfacePatch SurfacePatch object. """ with torch.no_grad(): surface_coefficients = self.model(patch_data) return SurfacePatch(patch_data, surface_coefficients)
[docs] def estimate_quantities(self, scalar_names: list[str]) -> dict: """ Estimate geometric quantities on the point cloud. This function returns a dictionary containing the estimated scalar and/or vector values. Args: scalar_names (list[str]): List of scalar names to estimate. This can be any of the following: 'xyz_coordinates', 'normals', 'tangents', 'mean_curvature', 'gaussian_curvature', 'pca_coordinates', 'normals_pca', 'tangents_pca', 'metric', 'shape', 'weingarten', 'inverse_metric', 'inverse_metric_derivatives', 'det_metric', 'laplace_beltrami_from_coefficients' Returns: dict: Dictionary containing the estimated scalar values. """ data_size = self.pcd.size(0) patch_dataloader = self.patch_loader(self.data) output = {} for batch in patch_dataloader: mask = batch.cluster_mask.flatten() indices = batch.ind[mask] patch = self.surface_patch(batch) for name in scalar_names: quantity = getattr(patch, name, None) if quantity is None: raise ValueError(f'Unknown scalar name: {name}') if name not in output: output[name] = torch.zeros((data_size, *quantity.shape[1:]), device=self.device) output[name][indices] = quantity[mask] return output
[docs] def flow_step(self, delta_t: float, subsample_radius: float, smooth_radius: float, smooth_x: bool) -> dict: """ Perform a single step of mean curvature flow on the point cloud. Args: delta_t (float): Time step for the flow. subsample_radius (float): Radius used for subsampling points. smooth_radius (float): Radius used for smoothing mean curvature. smooth_x (bool): Whether to smooth the point cloud before flow. Returns: dict: Dictionary containing the update point cloud data, normals, and mean curvature. """ if smooth_x: estimate = self.estimate_quantities(['xyz_coordinates']) x = estimate['xyz_coordinates'] self.pcd = x self.data['x'] = x estimate = self.estimate_quantities(['normals', 'mean_curvature']) x = self.pcd normals = estimate['normals'] mean_curvature = smooth_values_by_gaussian(x=x, values=estimate['mean_curvature'], radius=smooth_radius) new_x = x + delta_t * mean_curvature.view(-1, 1) * normals subsampled_indices = subsample_points_by_radius(new_x, subsample_radius) new_x = new_x[subsampled_indices] new_normals = normals[subsampled_indices] mean_curvature = mean_curvature[subsampled_indices] new_data = {'x': new_x, 'normals': new_normals, 'mean_curvature': mean_curvature} return new_data
[docs] def mean_flow(self, num_steps: int, save_data_per_step: int, delta_t: float, subsample_radius: float, smooth_radius: float, smooth_x: bool) -> list[dict]: """ Perform mean curvature flow on the point cloud. Args: num_steps (int): Number of steps to perform. save_data_per_step (int): Save data every n steps. delta_t (float): Time step for the flow. subsample_radius (float): Radius used for subsampling points. smooth_radius (float): Radius used for smoothing mean curvature. smooth_x (bool): Whether to smooth the point cloud before flow. Returns: list[dict]: List of dictionaries containing the updated point cloud data and normals at each saved time step. """ save_data = [] for i in tqdm(range(num_steps)): new_data = self.flow_step(delta_t=delta_t, subsample_radius=subsample_radius, smooth_radius=smooth_radius, smooth_x=smooth_x) self.data = new_data.copy() self.pcd = new_data['x'] self.orientation = new_data['normals'] if i % save_data_per_step == 0: save_data.append(new_data.copy()) return save_data
[docs] def gmls_weights(self, batch: Batch, radius: float=1., p: int=4) -> list[torch.Tensor]: """ Compute the weights for the generalized moving least squares (GMLS) method. Args: batch (Batch): Batch containing the input data. radius (float, optional): Radius to truncate weight function. Defaults to 1.. p (int, optional): Degree p of the weight function. Defaults to 4. Returns: list[torch.Tensor]: List of weight matrices for each batch. """ xs = [ batch[i].x[batch[i].mask.flatten(), :2] for i in range(batch.num_graphs) ] centers = batch.x[batch.cluster_mask.flatten(), :2] cxs = [x - centers[i].view(1, 2) for i, x in enumerate(xs)] Ws = [ torch.diag( F.relu(1 - cxs[j].norm(dim=1) / radius).pow(p)) for j in range(batch.num_graphs) ] return Ws
    def laplace_beltrami_legendre_blocks(self, surface: SurfacePatch
                                         ) -> tuple[torch.Tensor, ...]:
        """Laplace-Beltrami operator of Legendre basis functions at each
        center point.

        Args:
            surface (SurfacePatch): Surface patch object.

        Returns:
            tuple[torch.Tensor, ...]: One row tensor per center point
            (``torch.split`` returns a tuple, not a list).
        """
        lb_values = surface.laplace_beltrami_legendre_basis[surface.cluster_mask]
        # Split the stacked center rows into per-center (1, num_basis) blocks.
        return torch.split(lb_values, split_size_or_sections=1, dim=0)
[docs] def legendre_blocks(self, surface: SurfacePatch) -> list[torch.Tensor]: """Legendre basis functions blocked by batch. Args: surface (SurfacePatch): Surface patch object. Returns: list[torch.Tensor]: List of outputs per batch. """ mask = surface.mask batch = surface.batch[mask] legendre_values = surface.basis.evaluate(surface.x[mask, :2]) legendre_blocks = [ legendre_values[batch == i] for i in range(batch.max() + 1) ] return legendre_blocks
    def stiffness_on_batch(self, batch: Batch, radius: float=1.,
                           p: int=4) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the stiffness matrix entries for a batch of data.

        Args:
            batch (Batch): Batch containing the input data.
            radius (float, optional): Radius to truncate weight function.
                Defaults to 1..
            p (int, optional): Degree p of the weight function. Defaults to 4.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (2, nnz) indices and (nnz,)
            values of the stiffness-matrix entries for this batch.
        """
        Ws = self.gmls_weights(batch, radius=radius, p=p)
        surface = self.surface_patch(batch)
        legendre_blocks = self.legendre_blocks(surface)
        lb_blocks = self.laplace_beltrami_legendre_blocks(surface)
        # Weighted least-squares fit per graph: solve (L^T W L) S = L^T W.
        # Moved to CPU because the 'gelsd' driver is CPU-only.
        outputs = [
            torch.linalg.lstsq((leg.T @ W @ leg).cpu(),
                               (leg.T @ W).cpu(),
                               driver='gelsd')
            for leg, W in zip(legendre_blocks, Ws)
        ]
        # Stiffness row values: negative Laplace-Beltrami of the fitted basis
        # applied to the least-squares solution operator.
        stiffness_values = torch.cat([
            -lb @ output.solution.to(self.device)
            for lb, output in zip(lb_blocks, outputs)
        ], dim=1).flatten()
        # Row index: global index of each patch center; column indices: global
        # indices of the masked points in each graph.
        row_inds = batch.ind[surface.cluster_mask]
        column_inds = [
            batch[i].ind[batch[i].mask].view(1, -1)
            for i in range(batch.num_graphs)
        ]
        stiffness_indices = torch.cat([
            torch.cat(
                (row_inds[i] * torch.ones((1, column_inds[i].shape[1]),
                                          device=self.device),
                 column_inds[i]),
                dim=0)
            for i in range(batch.num_graphs)], dim=1
        )
        return stiffness_indices, stiffness_values
[docs] def stiffness_matrix_gmls(self, drop_ratio: float=0.1, radius: float=1., p: int=4, remove_outliers: bool=False) -> sp.coo_matrix: """ Compute the stiffness matrix using the generalized moving least squares (GMLS) method. Args: drop_ratio (float, optional): Ratio of points to drop. Defaults to 0.1. radius (float, optional): Radius to truncate weight function. Defaults to 1.. p (int, optional): Degree p of the weight function. Defaults to 4. remove_outliers (bool, optional): Whether to remove outliers. Defaults to False. Returns: sp.coo_matrix: Stiffness matrix. """ if remove_outliers: outputs = self.estimate_quantities(['x', 'pca_coordinates']) outlier_mask = (outputs['x' - outputs['pca_coordinates']].norm(dim=1) < 0.1) for k, v in self.data.items(): self.data[k] = v[outlier_mask] self.pcd = self.pcd[outlier_mask] self.orientation = self.orientation[outlier_mask] drop_inds = fps(self.pcd, ratio=drop_ratio) self.data['stiffness_mask'] = torch.ones(self.pcd.size(0), dtype=torch.bool, device=self.device) self.data['stiffness_mask'][drop_inds] = False self.patch_loader.center = 'gmls' patch_dataloader = self.patch_loader(self.data) stiffness_indices = [] stiffness_values = [] for batch in tqdm(patch_dataloader): batch.mask = batch.mask & batch.stiffness_mask indices, values = self.stiffness_on_batch(batch, radius=radius, p=p) stiffness_indices.append(indices) stiffness_values.append(values) stiffness_indices = torch.cat(stiffness_indices, dim=1) stiffness_values = torch.cat(stiffness_values) notmask = torch.logical_not(self.data['stiffness_mask']) skipped = torch.arange(self.data['stiffness_mask'].shape[0], device=self.device)[notmask] if len(skipped) > 0: subtract_index_value = torch.zeros_like(stiffness_indices[1], device=self.device) num_iters = skipped.shape[0] / 10 num_iters = int(num_iters) if num_iters % 1 == 0 else int(num_iters) + 1 for i in range(num_iters): start = i * 10 end = min((i+1) * 10, skipped.shape[0]) subtract_index_value += ( 
stiffness_indices[1].view(-1, 1) > skipped[start:end].view(1, -1) ).sum(dim=1) stiffness_indices[1] = stiffness_indices[1] - subtract_index_value.long() stiffness = sp.coo_matrix(( stiffness_values.cpu().numpy(), stiffness_indices.cpu().numpy().astype(np.int32) )) return stiffness