Source code for geotorch.fixedrank

import torch
from .lowrank import LowRank
from .exceptions import InverseError


def softplus_epsilon(x, epsilon=1e-6):
    return torch.nn.functional.softplus(x) + epsilon


def inv_softplus_epsilon(x, epsilon=1e-6):
    # Right inverse of softplus_epsilon: computes log(exp(y) - 1) via expm1 for
    # y = x - epsilon. For y > 20, log(expm1(y)) equals y up to machine precision
    # (and expm1(y) may overflow), so y is used directly for those entries.
    y = x - epsilon
    return torch.where(y > 20, y, y.expm1().log())
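

# The two helpers above form a (right) inverse pair: ``softplus_epsilon`` maps the
# reals onto (epsilon, inf) and ``inv_softplus_epsilon`` undoes it on that image.
# A minimal sanity check (illustration only, not part of the library source):
#
#     >>> y = torch.rand(4) + 0.1   # positive values in the image of softplus_epsilon
#     >>> torch.allclose(softplus_epsilon(inv_softplus_epsilon(y)), y, atol=1e-5)
#     True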


class FixedRank(LowRank):
    fs = {"softplus": (softplus_epsilon, inv_softplus_epsilon)}

    def __init__(self, size, rank, f="softplus", triv="expm"):
        r"""
        Manifold of non-square matrices of rank equal to ``rank``

        Args:
            size (torch.size): Size of the tensor to be parametrized
            rank (int): Rank of the matrices.
                It has to be less than or equal to
                :math:`\min(\texttt{size}[-1], \texttt{size}[-2])`
            f (str or callable or pair of callables): Optional. Either:

                - ``"softplus"``

                - A callable that maps real numbers to the interval
                  :math:`(0, \infty)`

                - A pair of callables such that the first maps the real
                  numbers onto :math:`(0, \infty)` and the second is a
                  (right) inverse of the first

                Default: ``"softplus"``
            triv (str or callable): Optional.
                A map that maps skew-symmetric matrices onto the orthogonal
                matrices surjectively. This is used to optimize the
                :math:`U` and :math:`V` in the SVD. It can be one of
                ``["expm", "cayley"]`` or a custom callable.
                Default: ``"expm"``
        """
        super().__init__(size, rank, triv=triv)
        f, inv = FixedRank.parse_f(f)
        self.f = f
        self.inv = inv

    @staticmethod
    def parse_f(f):
        if f in FixedRank.fs.keys():
            return FixedRank.fs[f]
        elif callable(f):
            return f, None
        elif isinstance(f, tuple) and callable(f[0]) and callable(f[1]):
            return f
        else:
            raise ValueError(
                "Argument f was not recognized and is "
                "not callable or a pair of callables. "
                "Should be one of {}. Found {}".format(list(FixedRank.fs.keys()), f)
            )

    def submersion(self, U, S, V):
        return super().submersion(U, self.f(S), V)

    def submersion_inv(self, X, check_in_manifold=True):
        U, S, V = super().submersion_inv(X, check_in_manifold)
        if self.inv is None:
            raise InverseError(self)
        return U, self.inv(S), V

    def in_manifold_singular_values(self, S, eps=1e-5):
        r"""
        Checks that a vector of singular values is in the manifold.

        For tensors with more than one dimension, the first dimensions are
        treated as batch dimensions.

        Args:
            S (torch.Tensor): Vector of singular values
            eps (float): Optional. Threshold at which the singular values are
                considered to be zero.
                Default: ``1e-5``
        """
        if not super().in_manifold_singular_values(S, eps):
            return False
        # We compute the \infty-norm of the kept singular values
        D = S[..., : self.rank]
        infty_norm = D.abs().max(dim=-1).values
        return (infty_norm > eps).all().item()
    def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6):
        r"""
        Returns a randomly sampled matrix on the manifold by sampling a matrix
        according to ``init_`` and projecting it onto the manifold.

        If the sampled matrix has fewer than ``self.rank`` singular values
        larger than ``eps``, the small ones are clamped to be at least ``eps``
        in absolute value, so that the returned matrix has rank exactly
        ``self.rank``.

        The output of this method can be used to initialize a parametrized
        tensor that has been parametrized with this or any other manifold
        as::

            >>> layer = nn.Linear(20, 20)
            >>> M = FixedRank(layer.weight.size(), rank=6)
            >>> geotorch.register_parametrization(layer, "weight", M)
            >>> layer.weight = M.sample()

        Args:
            init\_ (callable): Optional. A function that takes a tensor and
                fills it in place according to some distribution. See
                `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                Default: ``torch.nn.init.xavier_normal_``
            eps (float): Optional. Minimum singular value of the sampled matrix.
                Default: ``5e-6``
        """
        U, S, V = super().sample(factorized=True, init_=init_)
        with torch.no_grad():
            # S >= 0, as given by torch.linalg.eigvalsh()
            S[S < eps] = eps
            X = (U * S.unsqueeze(-2)) @ V.transpose(-2, -1)
            if self.transposed:
                X = X.transpose(-2, -1)
            return X
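

# Example usage (a sketch for illustration, not part of the library source).
# It follows the pattern shown in the ``sample`` docstring above; the
# registration call ``geotorch.register_parametrization`` is assumed to be the
# entry point exported by the installed version of the library.
#
#     >>> import torch.nn as nn
#     >>> layer = nn.Linear(30, 20)                   # weight has size 20 x 30
#     >>> M = FixedRank(layer.weight.size(), rank=4)
#     >>> geotorch.register_parametrization(layer, "weight", M)
#     >>> layer.weight = M.sample()                   # initialize on the manifold
#     >>> torch.linalg.matrix_rank(layer.weight)      # expected: tensor(4)
#
# A custom pair of callables may also be passed as ``f``. The pair below is a
# hypothetical choice: the first maps the reals onto a subset of (0, inf) and
# the second is a right inverse of the first.
#
#     >>> f_pair = (lambda x: x.exp() + 1e-6, lambda y: (y - 1e-6).log())
#     >>> M_exp = FixedRank(layer.weight.size(), rank=4, f=f_pair)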