# Source code for geotorch.sl

import torch
from .glp import GLp
from .fixedrank import FixedRank

[docs]class SL(GLp):
def __init__(self, size, f="softplus", triv="expm"):
    r"""
    Manifold of special linear matrices

    Args:
        size (torch.Size): Size of the tensor to be parametrized
        f (str or callable or pair of callables): Optional. Either:

            - ``"softplus"``

            - A callable that maps real numbers to the interval :math:`(0, \infty)`

            - A pair of callables such that the first maps the real numbers to
              :math:`(0, \infty)` and the second is a (right) inverse of the first

            Default: ``"softplus"``
        triv (str or callable): Optional.
            A map that maps skew-symmetric matrices onto the orthogonal matrices
            surjectively. This is used to optimize the :math:`U` and :math:`V` in the
            SVD. It can be one of ``["expm", "cayley"]`` or a custom
            callable. Default: ``"expm"``
    """
    # ``f`` goes through ``SL.parse_f`` before being handed to ``GLp``; ``size``
    # and ``triv`` are forwarded unchanged.
    super().__init__(size, SL.parse_f(f), triv)

@staticmethod
def parse_f(f_name):
    r"""
    Wraps ``f`` so that the values it produces multiply to one along the last
    dimension, i.e. the resulting singular values give :math:`|\det| = 1`.

    Args:
        f_name (str or callable or pair of callables): Either a key of
            ``FixedRank.fs``, a callable mapping the reals to
            :math:`(0, \infty)`, or a pair (tuple or list) whose first element
            is such a callable and whose second is a (right) inverse of it.

    Returns:
        For a known name or a pair: a tuple ``(f_sl, inv)``. For a single
        callable: the normalized callable ``f_sl``. Any other value is
        returned unchanged so that the parent class can validate it and
        report the error.
    """

    def _unit_product(f):
        # Build x -> f(x) / geometric_mean(f(x)) over the last dimension.
        def f_sl(x):
            y = f(x)
            # The geometric mean is computed in log-space: y.prod() over a few
            # hundred entries over/underflows in float32 (giving inf/nan),
            # while the mean of the logs stays well scaled. Requires y > 0,
            # which is what f is documented to produce.
            return y / y.log().mean(dim=-1, keepdim=True).exp()

        return f_sl

    if isinstance(f_name, str):
        if f_name in FixedRank.fs.keys():
            f, inv = FixedRank.parse_f(f_name)
            # inv stays valid: for any y whose entries already multiply to
            # one, f_sl(inv(y)) == y.
            return (_unit_product(f), inv)
        # Unknown names are passed through so the parent reports them.
        return f_name
    if callable(f_name):
        return _unit_product(f_name)
    if (
        isinstance(f_name, (tuple, list))
        and len(f_name) == 2
        and callable(f_name[0])
    ):
        f, inv = f_name
        return (_unit_product(f), inv)
    return f_name

def in_manifold_singular_values(self, S, eps=5e-3):
    r"""
    Checks whether the singular values ``S`` belong to a matrix of the manifold.

    First defers to the parent check, then requires the product of the
    singular values (reduced over the last dimension) to be within ``eps`` of
    one for every matrix in the batch.

    Args:
        S (torch.Tensor): Singular values; reduced over the last dimension
            (presumably shape ``(*, k)`` — confirm against callers).
        eps (float): Optional. Tolerance passed to the parent check and used
            for the product test. Default: ``5e-3``

    Returns:
        bool: ``True`` if every matrix passes both checks.
    """
    if not super().in_manifold_singular_values(S, eps):
        return False
    # The product of the singular values is |det X|; the sign of the
    # determinant is not checked here. Summing logs instead of calling
    # S.prod() keeps partial products from overflowing/underflowing for large
    # matrices, which would otherwise reject valid SL matrices.
    abs_det = S.log().sum(dim=-1).exp()
    # The \infty-norm of (|det| - 1) over the batch should be about zero.
    return ((abs_det - 1).abs() < eps).all().item()

def in_manifold(self, X, eps=5e-3):
    r"""
    Checks that a given matrix is in the manifold.

    Args:
        X (torch.Tensor or tuple): The input matrix or matrices of shape ``(*, n, k)``.
        eps (float): Optional. Threshold at which the singular values are
            considered to be zero.
            Default: ``5e-3``
    """
    # The check itself is inherited unchanged; this override exists only to
    # provide a more lax default tolerance than the parent's.
    parent_check = super().in_manifold
    return parent_check(X, eps)

[docs]    def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6, factorized=False):
r"""
Returns a randomly sampled matrix on the manifold by sampling a matrix according
to init_ and projecting it onto the manifold.

The output of this method can be used to initialize a parametrized tensor
that has been parametrized with this or any other manifold as::

>>> layer = nn.Linear(20, 20)
>>> M = SL(layer.weight.size(), rank=6)
>>> geotorch.register_parametrization(layer, "weight", M)
>>> layer.weight = M.sample()

Args:
init\_ (callable): Optional. A function that takes a tensor and fills it
in place according to some distribution. See
torch.init <https://pytorch.org/docs/stable/nn.init.html>_.
Default: torch.nn.init.xavier_normal_
eps (float): Optional. Minimum singular value of the sampled matrix.
Default: 5e-6
"""
U, S, V = super().sample(factorized=True, init_=init_)