import torch
from .glp import GLp
from .fixedrank import FixedRank

def _unit_geometric_mean(x):
    r"""Rescale ``x`` along its last dimension so that the product of its entries is one.

    Divides ``x`` by its geometric mean over ``dim=-1``. The geometric mean is
    computed in the log domain: the plain product ``x.prod(dim=-1)`` over- or
    underflows long before the geometric mean does (``0.7 ** 300`` is about
    ``3e-47``, already below the smallest float32), which would turn the
    result into ``inf``/``nan``. Assumes the entries of ``x`` are positive.
    """
    return x / x.log().mean(dim=-1, keepdim=True).exp()


class SL(GLp):
    def __init__(self, size, f="softplus", triv="expm"):
        r"""
        Manifold of special linear matrices

        Args:
            size (torch.Size): Size of the tensor to be parametrized
            f (str or callable or pair of callables): Optional. Either:

                - ``"softplus"``

                - A callable that maps real numbers to the interval :math:`(0, \infty)`

                - A pair of callables such that the first maps the real numbers to
                  :math:`(0, \infty)` and the second is a (right) inverse of the first

                Default: ``"softplus"``
            triv (str or callable): Optional.
                A map that maps skew-symmetric matrices onto the orthogonal matrices
                surjectively. This is used to optimize the :math:`U` and :math:`V` in the
                SVD. It can be one of ``["expm", "cayley"]`` or a custom
                callable. Default: ``"expm"``
        """
        super().__init__(size, SL.parse_f(f), triv)

    @staticmethod
    def parse_f(f_name):
        r"""Turn ``f_name`` into the map used for the singular values.

        For a name known to ``FixedRank.fs``, the corresponding map ``f`` is
        wrapped so that its output is rescaled to have product one (so the
        product of the singular values, i.e. :math:`|\det|`, is one), and the
        pair ``(f_sl, inv)`` is returned. Anything else is returned unchanged.
        """
        if f_name in FixedRank.fs.keys():
            f, inv = FixedRank.parse_f(f_name)

            def f_sl(x):
                return _unit_geometric_mean(f(x))

            return (f_sl, inv)
        else:
            # NOTE(review): custom callables are forwarded as-is and are not
            # rescaled here, so unless they already guarantee a unit product
            # the parametrized matrix may not have determinant one — confirm.
            return f_name

    def in_manifold_singular_values(self, S, eps=5e-3):
        r"""Check that the singular values ``S`` describe a special linear matrix.

        Args:
            S (torch.Tensor): Singular values, reduced over the last dimension.
            eps (float): Optional. Tolerance used both by the parent check and for
                :math:`|\det| - 1`. Default: ``5e-3``
        """
        if not super().in_manifold_singular_values(S, eps):
            return False
        # |det X| is the product of the singular values. Summing logs avoids the
        # over/underflow that a running product suffers for large matrices.
        det = S.log().sum(dim=-1).exp()
        # Equivalent to checking the \infty-norm of (det - 1) against eps
        return ((det - 1.0).abs() < eps).all().item()

    def in_manifold(self, X, eps=5e-3):
        r"""
        Checks that a given matrix is in the manifold.

        Args:
            X (torch.Tensor or tuple): The input matrix or matrices of shape ``(*, n, k)``.
            eps (float): Optional. Threshold at which the singular values are
                considered to be zero
                Default: ``5e-3``
        """
        # The purpose of this function is just to have a more lax default eps value
        return super().in_manifold(X, eps)

    def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6, factorized=False):
        r"""
        Returns a randomly sampled matrix on the manifold by sampling a matrix
        according to ``init_`` and projecting it onto the manifold.

        The output of this method can be used to initialize a parametrized tensor
        that has been parametrized with this or any other manifold as::

            >>> layer = nn.Linear(20, 20)
            >>> M = SL(layer.weight.size())
            >>> geotorch.register_parametrization(layer, "weight", M)
            >>> layer.weight = M.sample()

        Args:
            init\_ (callable): Optional. A function that takes a tensor and fills it
                in place according to some distribution. See
                `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                Default: ``torch.nn.init.xavier_normal_``
            eps (float): Optional. Minimum singular value of the sampled matrix.
                Default: ``5e-6``
            factorized (bool): Optional. If ``True``, return the factors
                ``(U, S, V)`` instead of the matrix ``U diag(S) V^T``.
                Default: ``False``
        """
        # NOTE(review): eps is not forwarded to the parent's sample(); confirm
        # whether the parent accepts it before wiring it through.
        U, S, V = super().sample(factorized=True, init_=init_)
        with torch.no_grad():
            # Rescale so the product of the singular values (|det|) is one.
            # Assumes the parent returns strictly positive singular values.
            S = _unit_geometric_mean(S)
        if factorized:
            return U, S, V
        return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1)