# Source code for gnp.geometry.surface

import torch
from torch_geometric.data import Batch
from functools import cached_property

from .legendre import Legendre2D

class SurfacePatch:
    """
    Surface Patch for computing geometric quantities from Legendre Coefficients.

    Each patch is treated as a height field ``h(u, v)`` over its local PCA
    plane. The first and second partial derivatives of ``h`` are evaluated
    once in ``__init__``; every geometric quantity is then exposed as a
    ``cached_property`` built from those derivatives.
    """

    # NOTE: ``Batch`` is written as a string annotation so that defining the
    # class does not require the name to be resolvable at definition time.
    def __init__(self, patch_data: "Batch", surface_coefficients: torch.Tensor):
        """
        Parameters
        ----------
        patch_data : Batch
            Patch data in local PCA coordinates. The attributes ``x``,
            ``batch``, ``mask``, ``cluster_mask``, ``center``, ``xy_scale``,
            ``z_scale`` and ``pca_vectors`` are read from it.
        surface_coefficients : torch.Tensor
            Legendre coefficients representing the surface.
        """
        self.patch = patch_data
        self.x = patch_data.x
        self.device = patch_data.x.device
        self.batch = patch_data.batch
        self.mask = patch_data.mask.flatten()
        self.cluster_mask = patch_data.cluster_mask.flatten()
        self.coefficients = surface_coefficients
        self.basis = Legendre2D(degree=3)
        self.centers = patch_data.center
        # xy_scale is gathered through ``batch`` (one row per point), whereas
        # z_scale is only unsqueezed here and is indexed with ``batch`` later
        # where a per-point value is needed.
        # NOTE(review): assumes both scales are 1-D with one entry per patch.
        self.xy_scale = patch_data.xy_scale[patch_data.batch].unsqueeze(1)
        self.z_scale = patch_data.z_scale.unsqueeze(1)
        self.pca_vectors = patch_data.pca_vectors.view(-1, 3, 3)
        # Five columns matching (h_u, h_v, h_uv, h_uu, h_vv): first
        # derivatives are divided by xy_scale, second derivatives by
        # xy_scale ** 2.
        self.derivative_scale = torch.reciprocal(
            torch.cat(
                (self.xy_scale.repeat(1, 2), self.xy_scale.pow(2).repeat(1, 3)),
                dim=1,
            )
        )

        # Shared inputs of the two basis evaluations, computed once.
        uv = self.patch.x[..., :2]
        scaled_coefficients = self.z_scale * self.coefficients

        self.h = self.basis.evaluate_from_coeffs(
            uv, scaled_coefficients, self.batch
        )
        h_derivatives = self.derivative_scale * self.basis.derivatives_from_coeffs(
            uv, scaled_coefficients, self.batch
        )
        # Split into five single-column tensors along dim 1.
        (self.h_u,
         self.h_v,
         self.h_uv,
         self.h_uu,
         self.h_vv) = torch.split(h_derivatives, 1, 1)

    @cached_property
    def local_coordinate_basis(self) -> torch.Tensor:
        """
        Return the local PCA basis coordinates at each point.

        Returns
        -------
        torch.Tensor
            ``pca_vectors`` gathered with ``batch``, i.e. one 3x3 basis per
            point.
        """
        return self.pca_vectors[self.batch]

    @cached_property
    def xyz_coordinates(self) -> torch.Tensor:
        """
        Compute the xyz coordinates of the surface patches.

        Returns
        -------
        torch.Tensor
            Patch center plus the scaled in-plane offset along the first two
            basis rows plus the height ``h`` along the third basis row.
        """
        pca_x = self.patch.x[..., :2].unsqueeze(2)
        pca_basis = self.local_coordinate_basis
        in_plane = (pca_x * pca_basis[:, :2]).sum(dim=1)
        out_of_plane = self.h * pca_basis[:, 2]
        return self.centers[self.batch] + self.xy_scale * in_plane + out_of_plane

    @cached_property
    def pca_coordinates(self) -> torch.Tensor:
        """
        Compute the PCA coordinates of the surface patches.

        Returns
        -------
        torch.Tensor
            The first two columns of ``patch.x`` concatenated with the height
            ``h`` divided by the per-point ``z_scale``.
        """
        return torch.cat(
            (self.patch.x[..., :2], self.h / self.z_scale[self.batch]),
            dim=1,
        )

    @cached_property
    def tangents_pca(self) -> torch.Tensor:
        """
        Compute the tangents of the surface patches in local coordinates.

        Returns
        -------
        torch.Tensor
            ``(1, 0, h_u)`` and ``(0, 1, h_v)`` stacked along dim 1.
        """
        # ones/zeros follow h_u so dtype and device always match the
        # derivative columns they are concatenated with.
        ones = torch.ones_like(self.h_u)
        zeros = torch.zeros_like(self.h_u)
        tangent_u = torch.cat((ones, zeros, self.h_u), dim=-1)
        tangent_v = torch.cat((zeros, ones, self.h_v), dim=-1)
        return torch.stack((tangent_u, tangent_v), dim=1)

    @cached_property
    def tangents(self) -> torch.Tensor:
        """
        Compute the tangents of the surface patches in global coordinates.

        Returns
        -------
        torch.Tensor
            ``tangents_pca`` right-multiplied by the per-point PCA basis.
        """
        return self.tangents_pca @ self.local_coordinate_basis

    @cached_property
    def normals_pca(self) -> torch.Tensor:
        """
        Compute the normals of the surface patches in local coordinates.

        Returns
        -------
        torch.Tensor
            ``(-h_u, -h_v, 1)`` divided by ``sqrt(1 + h_u**2 + h_v**2)``.
        """
        ones = torch.ones_like(self.h_u)
        unnormalized = torch.cat((-self.h_u, -self.h_v, ones), dim=-1)
        return unnormalized / self.det_metric.sqrt()

    @cached_property
    def normals(self) -> torch.Tensor:
        """
        Compute the normals of the surface patches in global coordinates.

        Returns
        -------
        torch.Tensor
            ``normals_pca`` right-multiplied by the per-point PCA basis.
        """
        pca_basis = self.local_coordinate_basis
        return (self.normals_pca.unsqueeze(1) @ pca_basis).squeeze(1)

    @cached_property
    def metric(self) -> torch.Tensor:
        """
        Compute the metric tensor of the surface patches.

        Returns
        -------
        torch.Tensor
            Per-point 2x2 matrices
            ``[[1 + h_u**2, h_u * h_v], [h_u * h_v, 1 + h_v**2]]``.
        """
        E = 1 + self.h_u.pow(2)
        F = self.h_u * self.h_v
        G = 1 + self.h_v.pow(2)
        return torch.stack(
            (torch.cat((E, F), dim=1), torch.cat((F, G), dim=1)), dim=1
        )

    @cached_property
    def shape(self) -> torch.Tensor:
        """
        Compute the shape tensor of the surface patches.

        Returns
        -------
        torch.Tensor
            Per-point 2x2 matrices ``[[h_uu, h_uv], [h_uv, h_vv]]`` divided
            by ``sqrt(1 + h_u**2 + h_v**2)``.
        """
        divisor = self.det_metric.sqrt()
        L = self.h_uu / divisor
        M = self.h_uv / divisor
        N = self.h_vv / divisor
        return torch.stack(
            (torch.cat((L, M), dim=1), torch.cat((M, N), dim=1)), dim=1
        )

    @cached_property
    def weingarten(self) -> torch.Tensor:
        """
        Compute the Weingarten tensor of the surface patches.

        Returns
        -------
        torch.Tensor
            ``inverse_metric @ shape`` for each point.
        """
        # Reuse the cached inverse rather than inverting the metric again.
        return self.inverse_metric @ self.shape

    @cached_property
    def gaussian_curvature(self) -> torch.Tensor:
        """
        Compute the Gaussian curvature of the surface patches.

        Returns
        -------
        torch.Tensor
            Determinant of the Weingarten tensor at each point.
        """
        return torch.linalg.det(self.weingarten)

    @cached_property
    def mean_curvature(self) -> torch.Tensor:
        """
        Compute the mean curvature of the surface patches.

        Returns
        -------
        torch.Tensor
            Half the trace of the Weingarten tensor at each point.
        """
        return 0.5 * self.weingarten.diagonal(dim1=1, dim2=2).sum(dim=1)

    @cached_property
    def inverse_metric(self) -> torch.Tensor:
        """
        Compute the inverse metric tensor of the surface patches.

        Returns
        -------
        torch.Tensor
            Matrix inverse of ``metric`` at each point.
        """
        return torch.linalg.inv(self.metric)

    @cached_property
    def inverse_metric_derivatives(self) -> torch.Tensor:
        """
        Compute the derivatives of the inverse metric tensor of the surface
        patches.

        The inverse metric is ``adj / det`` with
        ``adj = [[1 + h_v**2, -h_u h_v], [-h_u h_v, 1 + h_u**2]]`` and
        ``det = 1 + h_u**2 + h_v**2``, so each derivative is
        ``d(1/det) * adj + d(adj) / det``.

        Returns
        -------
        torch.Tensor
            The u- and v-derivatives of the 2x2 inverse metric stacked along
            a new last dimension (index 0 is d/du, index 1 is d/dv).
        """
        h_u, h_v = self.h_u, self.h_v
        h_uu, h_uv, h_vv = self.h_uu, self.h_uv, self.h_vv
        det_g = self.det_metric

        adjugate = torch.cat(
            (1 + h_v.pow(2), -h_u * h_v, -h_u * h_v, 1 + h_u.pow(2)),
            dim=1,
        ).view(-1, 2, 2)

        # Derivatives of the adjugate entries. The diagonal terms
        # d/du(1 + h_v**2) = 2 h_v h_uv and d/dv(1 + h_u**2) = 2 h_u h_uv
        # are non-zero whenever the mixed derivative h_uv is non-zero.
        off_diag_u = -h_uu * h_v - h_u * h_uv
        off_diag_v = -h_uv * h_v - h_u * h_vv
        adjugate_u = torch.cat(
            (2 * h_v * h_uv, off_diag_u, off_diag_u, 2 * h_u * h_uu),
            dim=1,
        ).view(-1, 2, 2)
        adjugate_v = torch.cat(
            (2 * h_v * h_vv, off_diag_v, off_diag_v, 2 * h_u * h_uv),
            dim=1,
        ).view(-1, 2, 2)

        # d(1/det)/du and d(1/det)/dv.
        inv_det_u = (
            -(2 * h_u * h_uu + 2 * h_v * h_uv) / det_g.pow(2)
        ).unsqueeze(2)
        inv_det_v = (
            -(2 * h_u * h_uv + 2 * h_v * h_vv) / det_g.pow(2)
        ).unsqueeze(2)

        ginv_u = inv_det_u * adjugate + adjugate_u / det_g.unsqueeze(2)
        ginv_v = inv_det_v * adjugate + adjugate_v / det_g.unsqueeze(2)
        return torch.stack((ginv_u, ginv_v), dim=-1)

    @cached_property
    def det_metric(self) -> torch.Tensor:
        """
        Compute the determinant of the metric tensor of the surface patches.

        Returns
        -------
        torch.Tensor
            ``1 + h_u**2 + h_v**2`` at each point.
        """
        return 1 + self.h_u.pow(2) + self.h_v.pow(2)

    @cached_property
    def laplace_beltrami_first_terms(self) -> torch.Tensor:
        """
        Compute the first terms of the Laplace-Beltrami operator.

        These terms are multiplied to the second derivatives of the function.

        Returns
        -------
        torch.Tensor
            ``sqrt(det_metric)`` times the inverse metric, flattened to four
            columns in row-major order.
        """
        flat_inverse = self.inverse_metric.contiguous().view(-1, 4)
        return self.det_metric.sqrt() * flat_inverse

    @cached_property
    def laplace_beltrami_second_terms(self) -> torch.Tensor:
        """
        Compute the second terms of the Laplace-Beltrami operator.

        These terms are multiplied to the first derivatives of the function.
        Writing ``adj`` for ``det * inverse_metric``, the four columns are
        ``d/du(adj_uu / sqrt(det))``, ``d/du(adj_uv / sqrt(det))``,
        ``d/dv(adj_vu / sqrt(det))`` and ``d/dv(adj_vv / sqrt(det))``.

        Returns
        -------
        torch.Tensor
            The four terms concatenated along dim 1.
        """
        h_u, h_v = self.h_u, self.h_v
        h_uu, h_uv, h_vv = self.h_uu, self.h_uv, self.h_vv
        det_g = self.det_metric
        sqrt_g = det_g.sqrt()
        g_three_halves = det_g.pow(1.5)

        # Half of d(det)/du and d(det)/dv.
        half_det_u = h_u * h_uu + h_v * h_uv
        half_det_v = h_u * h_uv + h_v * h_vv

        uu = ((2 * h_v * h_uv / sqrt_g)
              - (half_det_u * (1 + h_v.pow(2)) / g_three_halves))
        uv = (-((h_uu * h_v + h_u * h_uv) / sqrt_g)
              + (half_det_u * (h_u * h_v) / g_three_halves))
        vu = (-((h_uv * h_v + h_u * h_vv) / sqrt_g)
              + (half_det_v * (h_u * h_v) / g_three_halves))
        vv = ((2 * h_u * h_uv / sqrt_g)
              - (half_det_v * (1 + h_u.pow(2)) / g_three_halves))
        return torch.cat((uu, uv, vu, vv), dim=1)

    @cached_property
    def laplace_beltrami_legendre_basis(self) -> torch.Tensor:
        """
        Compute the Laplace-Beltrami operator on each of the Legendre basis
        functions.

        Returns
        -------
        torch.Tensor
            The Laplace-Beltrami operator applied to every basis function at
            every point.
            NOTE(review): presumably (n, 16) for the degree-3 basis; this
            depends on ``Legendre2D.evaluate_derivatives`` -- confirm.
        """
        derivatives = self.derivative_scale.unsqueeze(-1) * torch.stack(
            self.basis.evaluate_derivatives(self.patch.x[..., :2]), dim=1
        )
        # Same ordering as the surface derivatives: (u, v, uv, uu, vv).
        fu, fv, fuv, fuu, fvv = torch.unbind(derivatives, dim=1)

        lb_first = self.laplace_beltrami_first_terms.unsqueeze(-1)
        lb_second = self.laplace_beltrami_second_terms.unsqueeze(-1)

        second_order = lb_first * torch.stack((fuu, fuv, fuv, fvv), dim=1)
        first_order = lb_second * torch.stack((fu, fv, fu, fv), dim=1)
        legendre_lb = (second_order + first_order).sum(dim=1)
        return legendre_lb / self.det_metric.sqrt()

    def laplace_beltrami_from_coefficients(self,
                                           function_coefficients: torch.Tensor
                                           ) -> torch.Tensor:
        """
        Compute the Laplace-Beltrami operator from the surface coefficients
        and the function coefficients.

        Parameters
        ----------
        function_coefficients : torch.Tensor
            Legendre coefficients representing the function; indexed with
            ``batch`` so each point uses the coefficients of its own patch.

        Returns
        -------
        torch.Tensor
            Laplace-Beltrami operator applied to the function, one column
            (``keepdim=True``) with one row per point.
        """
        per_point_coefficients = function_coefficients[self.batch]
        weighted = per_point_coefficients * self.laplace_beltrami_legendre_basis
        return weighted.sum(dim=1, keepdim=True)