Source code for geotorch.skew

import torch
from torch import nn
from .exceptions import VectorError, NonSquareError, InManifoldError


class Skew(nn.Module):
    def __init__(self, size, lower=True):
        r"""
        Vector space of skew-symmetric matrices, parametrized in terms of
        the upper or lower triangular part of a matrix.

        Args:
            size (torch.Size): Size of the tensor to be parametrized. The last
                two dimensions must be equal; any leading dimensions are kept
                as batch dimensions.
            lower (bool): Optional. Uses the lower triangular part of the matrix
                to parametrize the matrix. Default: ``True``
        """
        super().__init__()
        n, tensorial_size = Skew.parse_size(size)
        self.n = n
        self.tensorial_size = tensorial_size
        self.lower = lower
        # Empty, non-persistent buffer: it moves with the module on
        # ``.to(device/dtype)`` so ``sample`` can allocate with the module's
        # device/dtype, without being stored in the state dict.
        self.register_buffer("_reference", torch.empty(0), persistent=False)

    @classmethod
    def parse_size(cls, size):
        """Split ``size`` into the matrix order and the batch dimensions.

        Args:
            size: Sequence of ints whose last two entries are the matrix
                dimensions.

        Returns:
            tuple: ``(n, tensorial_size)`` with ``n = size[-1]`` and
            ``tensorial_size = size[:-2]``.

        Raises:
            VectorError: if ``size`` has fewer than two entries.
            NonSquareError: if the last two entries differ.
        """
        if len(size) < 2:
            raise VectorError(cls.__name__, size)
        n, k = size[-2:]
        tensorial_size = size[:-2]
        if n != k:
            raise NonSquareError(cls.__name__, size)
        return n, tensorial_size

    @staticmethod
    def frame(X, lower):
        """Map an arbitrary square (batch of) matrix to a skew-symmetric one.

        Keeps the strictly lower (``lower=True``) or strictly upper
        (``lower=False``) triangular part ``T`` of ``X`` and returns
        ``T - T^T``.
        """
        if lower:
            X = X.tril(-1)
        else:
            X = X.triu(1)
        return X - X.transpose(-2, -1)

    def forward(self, X):
        """Return the skew-symmetric matrix parametrized by ``X``.

        Raises:
            VectorError: if ``X`` has fewer than two dimensions.
            NonSquareError: if the last two dimensions of ``X`` differ.
        """
        if len(X.size()) < 2:
            raise VectorError(type(self).__name__, X.size())
        if X.size(-2) != X.size(-1):
            raise NonSquareError(type(self).__name__, X.size())
        return self.frame(X, self.lower)

    def right_inverse(self, X, check_in_manifold=True, tol=1e-4):
        """Return a tensor ``Y`` such that ``self.forward(Y)`` reproduces ``X``.

        Args:
            X: Skew-symmetric (batch of) square matrix.
            check_in_manifold (bool): Optional. If ``True``, verify that
                ``X`` is skew-symmetric up to ``tol``. Default: ``True``
            tol (float): Optional. Absolute tolerance for the check.
                Default: ``1e-4``

        Raises:
            InManifoldError: if the check is enabled and ``X`` is not
                skew-symmetric within ``tol``.
        """
        if check_in_manifold and not torch.allclose(X, -X.mT, atol=tol):
            raise InManifoldError(X, self)
        # We assume that X is skew_symmetric, so X = T - T^T where T is its
        # strictly lower (or upper) triangular part; ``frame(T)`` gives back X.
        if self.lower:
            return X.tril(-1)
        else:
            return X.triu(1)

    @staticmethod
    def in_manifold(X):
        """Whether ``X`` is a (batch of) square skew-symmetric matrix.

        Uses ``torch.allclose`` with its default tolerances.
        """
        return (
            X.dim() >= 2
            and X.size(-2) == X.size(-1)
            and torch.allclose(X, -X.transpose(-2, -1))
        )

    def sample(self, init_=nn.init.xavier_normal_, lower=True):
        r"""
        Returns a randomly sampled matrix on the manifold as

        .. math::

            \operatorname{tril}(W, -1) - \operatorname{tril}(W, -1)^\intercal
            \qquad W_{i,j} \sim \texttt{init\_}

        if ``lower`` is ``True``, and as

        .. math::

            \operatorname{triu}(W, 1) - \operatorname{triu}(W, 1)^\intercal
            \qquad W_{i,j} \sim \texttt{init\_}

        otherwise.

        By default ``init\_`` is a (xavier) normal distribution, so the strictly
        lower (or upper) triangular entries of the returned matrix are
        independent centered Gaussians, the diagonal is zero, and the remaining
        entries are their negatives.

        The output of this method can be used to initialize a parametrized
        tensor that has been parametrized with this or any other manifold as::

            >>> layer = nn.Linear(20, 20)
            >>> M = Skew(layer.weight.size())
            >>> torch.nn.utils.parametrize.register_parametrization(layer, "weight", M)
            >>> layer.weight = M.sample()

        Args:
            init\_ (callable): Optional.
                A function that takes a tensor and fills it in place according
                to some distribution. See
                `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                Default: ``torch.nn.init.xavier_normal_``
            lower (bool): Optional. Whether to build the sample from the
                strictly lower (``True``) or strictly upper (``False``)
                triangular part of ``W``. This is independent of the
                ``lower`` given to the constructor. Default: ``True``
        """
        with torch.no_grad():
            # Unpack instead of ``tensorial_size + (n, n)``: ``size`` may have
            # been given as a list, in which case ``list + tuple`` raises.
            shape = (*self.tensorial_size, self.n, self.n)
            X = self._reference.new_empty(shape)
            init_(X)
            if lower:
                X.tril_(-1)
            else:
                X.triu_(1)
            return X - X.mT