Source code for gnp.dataset.pca

import torch
from torch_geometric.data import Data, Batch
from typing import Optional

class PCAPatch:
    """
    Local PCA frame for a single patch of points.

    The frame is estimated from the masked rows of ``data.x`` by
    :meth:`initialize`; :meth:`to_pca` then expresses the patch in that frame
    and rescales it.

    Attributes
    ----------
    data
        The patch passed to :meth:`initialize` (``None`` until then).
    mask
        Flattened ``data.mask``; selects the rows of ``data.x`` used for the fit.
    degree : int
        Stored as given and copied onto the output of :meth:`to_pca`.
    basis : str
        Always ``'legendre'``; copied onto the output of :meth:`to_pca`.
    center
        Mean of the masked points, kept with a leading dimension of 1.
    orientation
        Reference direction used to orient the frame (``None`` until
        :meth:`initialize` runs).
    pca_vectors
        Rows are the axes of the local frame.
    xy_scale
        In-plane scale; only available after :meth:`to_pca`.
    z_scale
        Signed out-of-plane scale, a one-element tensor.
    min_z_scale : float
        Lower bound on the magnitude of ``z_scale``.
    """

    def __init__(self, data: Optional["Data"] = None, degree: int = 3,
                 min_z_scale: float = 1e-3):
        """
        Parameters
        ----------
        data : Data, optional
            Patch to fit immediately. When ``None`` the object stays
            uninitialized until :meth:`initialize` is called.
        degree : int
            Value stored on the instance and forwarded to the output data.
        min_z_scale : float
            Smallest allowed magnitude for ``z_scale``.
        """
        self.data = data
        self.mask = None
        self.degree = degree
        self.basis = 'legendre'
        self.center = None
        # Declared here so the attribute exists before initialize() is called.
        self.orientation = None
        self.pca_vectors = None
        self.xy_scale = None
        self.z_scale = None
        self.min_z_scale = min_z_scale
        if self.data is not None:
            self.initialize(data)

    def initialize(self, data: "Data"):
        """
        Initialize the PCA basis for the input patch.

        Parameters
        ----------
        data : Data
            Patch data. Must expose ``x`` and ``mask`` as attributes; a
            ``normals`` attribute is optional and, when present, is used to
            orient the frame.

        Notes
        -----
        The sign of ``z_scale`` is drawn at random on every call
        (presumably a flip augmentation -- confirm with the training code).
        The third singular value is indexed directly, so the masked patch is
        expected to yield at least three singular values.
        """
        self.data = data
        self.mask = data.mask.view(-1)
        patch_data = self.data.x[self.mask]
        self.center = patch_data.mean(dim=0, keepdim=True)

        _, S, Vh = torch.linalg.svd(patch_data - self.center,
                                    full_matrices=False)
        pca_vectors = Vh.clone()

        # Random sign for the out-of-plane scale. The index is converted to a
        # plain int rather than indexing a Python sequence with a tensor.
        sign = (-1, 1)[int(torch.randperm(2)[0])]
        z_scale = S[2] * torch.ones(1, device=S.device)
        z_scale *= sign

        # Reference direction: the normalized mean normal when normals are
        # available and well defined, otherwise the patch center.
        orientation = None
        normals = getattr(data, 'normals', None)
        if normals is not None:
            orientation = normals.mean(dim=0, keepdim=True)
            orientation = orientation / orientation.norm(dim=-1, keepdim=True)
            if orientation.isnan().any():
                # A zero mean normal cannot be normalized.
                orientation = None
        self.orientation = self.center if orientation is None else orientation

        # Swap the two leading axes if their cross product points away from
        # the reference direction, then rebuild the third axis from them so
        # the frame is right-handed.
        cross = torch.linalg.cross(Vh[0], Vh[1])
        if (cross * self.orientation).sum() < 0:
            pca_vectors[0] = Vh[1].clone()
            pca_vectors[1] = Vh[0].clone()
        pca_vectors[2] = torch.linalg.cross(pca_vectors[0], pca_vectors[1])

        self.pca_vectors = pca_vectors
        # The in-plane scale depends on the projected coordinates and is
        # recomputed by to_pca().
        self.xy_scale = None
        self.z_scale = 2 * (z_scale / (patch_data.shape[0] - 1) ** 0.5)
        if self.z_scale.abs() < self.min_z_scale:
            # Keep the drawn sign but clamp the magnitude from below.
            self.z_scale = sign * self.min_z_scale * torch.ones(
                1, device=S.device)

    def to_pca(self) -> "Data":
        """
        Perform change of coordinates to local PCA basis for input patch.

        Returns
        -------
        Data
            Data object of patch in PCA basis. ``normals`` is included only
            when the input patch provides normals. Every other key of the
            input that is not recomputed here is copied over unchanged.

        Raises
        ------
        RuntimeError
            If the patch has not been initialized.
        """
        if self.data is None or self.pca_vectors is None:
            raise RuntimeError(
                'PCAPatch.to_pca() called before initialize()')

        projection = self.pca_vectors.t()
        pca_coords = (self.data.x - self.center) @ projection

        # In-plane scale: largest in-plane distance of a masked point from
        # the center.
        self.xy_scale = pca_coords[self.mask, :2].norm(dim=-1).max()
        scaling = torch.tensor(
            [self.xy_scale, self.xy_scale, self.z_scale],
            device=self.xy_scale.device,
        ).view(-1, 3)
        pca_coords = pca_coords / scaling

        original_x = self.data.get('original_x', None)
        if original_x is not None:
            orig_pca_coords = ((original_x - self.center) @ projection) / scaling
        else:
            orig_pca_coords = pca_coords

        # edge_attr is transformed as two concatenated 3-vectors: both halves
        # get the same centering, rotation and scaling.
        pca_edges = ((self.data.edge_attr - self.center.repeat(1, 2))
                     @ torch.block_diag(projection, projection))
        pca_edges /= scaling.repeat(1, 2)

        fields = {'x': pca_coords, 'original_x': orig_pca_coords}
        # initialize() accepts patches without normals, so only rotate and
        # attach them when they exist instead of failing here.
        normals = self.data.get('normals', None)
        if normals is not None:
            fields['normals'] = normals @ projection
        fields.update(
            edge_attr=pca_edges,
            mask=self.mask,
            degree=self.degree,
            xy_scale=self.xy_scale,
            z_scale=self.z_scale,
            pca_vectors=self.pca_vectors,
            center=self.center,
            basis=self.basis,
        )
        out_data = Data(**fields)

        # Carry over every remaining input key untouched.
        recomputed = ('x', 'original_x', 'normals', 'mask', 'edge_attr')
        for k, v in self.data.items():
            if k not in recomputed:
                out_data[k] = v
        return out_data
class PCABatch:
    """
    Holds one :class:`PCAPatch` per graph and converts whole batches at once.
    """

    def __init__(self, num_graphs: int, **kwargs):
        """
        Build the per-graph PCA patches.

        Parameters
        ----------
        num_graphs : int
            Number of PCA patches in the batch.
        **kwargs
            Forwarded unchanged to every :class:`PCAPatch` constructor.
        """
        self.num_graphs = num_graphs
        self.pcas = [PCAPatch(**kwargs) for _ in range(num_graphs)]
        self.normal_scale = 1

    def to_pca(self, data: Batch) -> Batch:
        """
        Express every graph of the batch in its own local PCA basis.

        Each of the first ``num_graphs`` patches is re-fitted on the graph at
        the same index of ``data``; all stored patches are then converted and
        collated again.

        Parameters
        ----------
        data : Batch
            Input batch of data.

        Returns
        -------
        Batch
            Batch of data in PCA basis.
        """
        for idx in range(self.num_graphs):
            patch = self.pcas[idx]
            patch.initialize(data[idx])
        converted = [patch.to_pca() for patch in self.pcas]
        return Batch.from_data_list(converted)