Source code for geotorch.pssdfixedrank

import torch

from .symmetric import SymF
from .fixedrank import softplus_epsilon, inv_softplus_epsilon

class PSSDFixedRank(SymF):
    # Named choices for ``f``: each name maps to a pair ``(f, right inverse of f)``.
    fs = {"softplus": (softplus_epsilon, inv_softplus_epsilon)}

    def __init__(self, size, rank, f="softplus", triv="expm"):
        r"""
        Manifold of symmetric positive semidefinite matrices of rank :math:`r`.

        Args:
            size (torch.Size): Size of the tensor to be parametrized
            rank (int): Rank of the matrices.
                It has to be less or equal to
                :math:`\min(\texttt{size}[-1], \texttt{size}[-2])`
            f (str or callable or pair of callables): Optional. Either:

                - ``"softplus"``

                - A callable that maps real numbers to the interval :math:`(0, \infty)`

                - A pair of callables such that the first maps the real numbers to
                  :math:`(0, \infty)` and the second is a (right) inverse of the first

                Default: ``"softplus"``
            triv (str or callable): Optional.
                A map that maps skew-symmetric matrices onto the orthogonal matrices
                surjectively. This is used to optimize the :math:`Q` in the eigenvalue
                decomposition. It can be one of ``["expm", "cayley"]`` or a custom
                callable. Default: ``"expm"``

        Raises:
            ValueError: If ``f`` is not one of the accepted forms listed above.
        """
        super().__init__(size, rank, PSSDFixedRank.parse_f(f), triv)

    @staticmethod
    def parse_f(f):
        r"""
        Normalizes the ``f`` argument into a pair ``(f, f_inv)``.

        Args:
            f (str or callable or pair of callables): A key of
                ``PSSDFixedRank.fs``, a single callable, or a tuple/list
                holding exactly two callables.

        Returns:
            tuple: ``(f, f_inv)``. ``f_inv`` is ``None`` when only a single
            callable was given.

        Raises:
            ValueError: If ``f`` is none of the accepted forms.
        """
        # Check for a string before the dict lookup: looking up an unhashable
        # value (e.g. a list) would raise TypeError instead of ValueError.
        if isinstance(f, str) and f in PSSDFixedRank.fs:
            return PSSDFixedRank.fs[f]
        if callable(f):
            return f, None
        if (
            isinstance(f, (tuple, list))
            and len(f) == 2
            and callable(f[0])
            and callable(f[1])
        ):
            # tuple() returns a tuple argument unchanged and converts a list pair.
            return tuple(f)
        raise ValueError(
            "Argument f was not recognized and is "
            "not callable or a pair of callables. "
            "Should be one of {}. Found {}".format(list(PSSDFixedRank.fs.keys()), f)
        )

    def in_manifold_eigen(self, L, eps=1e-6):
        r"""
        Checks that an ascending ordered vector of eigenvalues is in the manifold.

        On top of the checks done by ``SymF.in_manifold_eigen``, this requires
        the last ``rank`` entries of ``L`` (the largest ones, since ``L`` is
        ascending) to be at least ``eps``.

        Args:
            L (torch.Tensor): Vector of eigenvalues of shape `(*, rank)`
            eps (float): Optional. Threshold at which the eigenvalues are
                    considered to be zero
                    Default: ``1e-6``

        Returns:
            bool: Whether ``L`` passes both checks.
        """
        return (
            super().in_manifold_eigen(L, eps)
            and (L[..., -self.rank :] >= eps).all().item()
        )

    def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6):
        r"""
        Returns a randomly sampled matrix on the manifold as

        .. math::

            WW^\intercal \qquad W_{i,j} \sim \texttt{init_}

        The eigenvalues returned by ``SymF.sample(factorized=True)`` that are
        smaller than ``eps`` are raised to ``eps`` before the matrix is
        assembled, so every one of them is at least ``eps``.

        The output of this method can be used to initialize a parametrized tensor as::

            >>> layer = nn.Linear(20, 20)
            >>> M = PSSDFixedRank(layer.weight.size(), rank=6)
            >>> geotorch.register_parametrization(layer, "weight", M)
            >>> layer.weight = M.sample()

        Args:
            init\_ (callable): Optional.
                    A function that takes a tensor and fills it in place according
                    to some distribution. See
                    `torch.nn.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                    Default: ``torch.nn.init.xavier_normal_``
            eps (float): Optional. Minimum eigenvalue of the sampled matrix.
                    Default: ``5e-6``

        Returns:
            torch.Tensor: The matrix :math:`Q \operatorname{diag}(L) Q^\intercal`
            built from the clamped factorized sample.
        """
        L, Q = super().sample(factorized=True, init_=init_)
        with torch.no_grad():
            # The eigenvalues are expected to be non-negative. Tiny ones, or ones
            # pushed slightly negative by round-off, are lifted to eps so that
            # none of them collapses to zero.
            L.clamp_(min=eps)
        # Q * L.unsqueeze(-2) scales column j of Q by L[j], i.e. Q @ diag(L),
        # without materializing the diagonal matrix.
        return (Q * L.unsqueeze(-2)) @ Q.transpose(-2, -1)