Source code for geotorch.almostorthogonal

import torch
from .lowrank import LowRank
from .exceptions import VectorError, InManifoldError, InverseError
from .utils import _extra_repr

def scaled_sigmoid(t):
    """Sigmoid shifted and stretched so that it is centred at zero.

    Computes ``2 * (sigmoid(t) - 0.5)`` elementwise, which sends ``0`` to ``0``
    and takes values strictly between ``-1`` and ``1`` mathematically.
    """
    centred = torch.sigmoid(t) - 0.5
    return 2.0 * centred

def inv_scaled_sigmoid(t):
    """Inverse of ``scaled_sigmoid``.

    Rescales ``t`` from ``(-1, 1)`` to ``(0, 1)`` and returns the log-odds
    of the rescaled value, computed elementwise.
    """
    unit = 0.5 * t + 0.5
    odds = unit / (1.0 - unit)
    return torch.log(odds)

class AlmostOrthogonal(LowRank):
    # Built-in choices for ``f``: name -> (map, right inverse of that map)
    fs = {
        "scaled_sigmoid": (scaled_sigmoid, inv_scaled_sigmoid),
        "tanh": (torch.tanh, torch.atanh),
        "sin": (torch.sin, torch.asin),
    }

    def __init__(self, size, lam, f="sin", triv="expm"):
        r"""
        Manifold of matrices with singular values in the interval
        :math:`(1-\lambda, 1+\lambda)`.

        The possible default maps are the :math:`\sin,\,\tanh` functions and a scaled
        sigmoid. The sigmoid is scaled as :math:`\operatorname{scaled\_sigmoid}(x) = 2\sigma(x) - 1`
        where :math:`\sigma` is the usual sigmoid function.
        This is done so that the image of the scaled sigmoid is :math:`(-1, 1)`.

        Args:
            size (torch.size): Size of the tensor to be parametrized
            lam (float): Radius of the interval. A float in the interval :math:`[0, 1]`
            f (str or callable or pair of callables): Optional. Either:

                - One of ``["scaled_sigmoid", "tanh", "sin"]``

                - A callable that maps real numbers to the interval :math:`(-1, 1)`

                - A pair of callables such that the first maps the real numbers to
                  :math:`(-1, 1)` and the second is a (right) inverse of the first

                Default: ``"sin"``
            triv (str or callable): Optional.
                A map that maps skew-symmetric matrices onto the orthogonal matrices
                surjectively. This is used to optimize the :math:`U` and :math:`V` in the
                SVD. It can be one of ``["expm", "cayley"]`` or a custom
                callable. Default: ``"expm"``

        Raises:
            VectorError: If ``size`` has fewer than two dimensions
            ValueError: If ``lam`` is not in :math:`[0, 1]` (this includes NaN) or
                if ``f`` is not one of the accepted forms
        """
        super().__init__(size, AlmostOrthogonal.rank(size), triv=triv)
        # Negated chained comparison rather than ``lam < 0 or lam > 1`` so that
        # NaN, for which every comparison is False, is rejected as well
        if not (0.0 <= lam <= 1.0):
            raise ValueError("The radius has to be between 0 and 1. Got {}".format(lam))
        self.lam = lam
        self.f, self.inv = AlmostOrthogonal.parse_f(f)

    @staticmethod
    def parse_f(f):
        """Normalize the ``f`` argument into a pair ``(f, inv)``.

        Args:
            f (str or callable or pair of callables): A key of
                ``AlmostOrthogonal.fs``, a single callable, or a 2-tuple of
                callables ``(f, inv)``

        Returns:
            tuple: ``(f, inv)``. ``inv`` is ``None`` when only a single callable
            was given

        Raises:
            ValueError: If ``f`` is none of the accepted forms
        """
        # Only strings can name a built-in map. Checking the type first keeps
        # unhashable arguments (e.g. a list) from raising TypeError in the lookup
        if isinstance(f, str) and f in AlmostOrthogonal.fs:
            return AlmostOrthogonal.fs[f]
        if callable(f):
            return f, None
        # Require exactly two entries so that malformed tuples are reported
        # here with a clear message rather than failing later on unpacking
        if isinstance(f, tuple) and len(f) == 2 and all(callable(g) for g in f):
            return f
        raise ValueError(
            "Argument f was not recognized and is "
            "not callable or a pair of callables. "
            "Should be one of {}. Found {}".format(
                list(AlmostOrthogonal.fs.keys()), f
            )
        )

    @classmethod
    def rank(cls, size):
        """Return the smaller of the last two entries of ``size``.

        Raises:
            VectorError: If ``size`` has fewer than two entries
        """
        if len(size) < 2:
            raise VectorError(cls.__name__, size)
        return min(size[-2:])

    def submersion(self, U, S, V):
        """Map ``S`` to ``1 + lam * f(S)`` and delegate to the parent submersion."""
        S = 1.0 + self.lam * self.f(S)
        return super().submersion(U, S, V)

    def submersion_inv(self, X, check_in_manifold=True):
        """Return ``(U, S, V)`` with ``S`` pulled back through the inverse of ``f``.

        Args:
            X: Passed unchanged to the parent ``submersion_inv``
            check_in_manifold (bool): Optional. If ``True``, raise when the
                singular values returned by the parent are not accepted by
                ``in_manifold_singular_values``. Default: ``True``

        Raises:
            InverseError: If no inverse of ``f`` was provided
            InManifoldError: If ``check_in_manifold`` is set and the check fails
        """
        if self.inv is None:
            raise InverseError(self)
        # NOTE(review): ``check_in_manifold`` is not forwarded to the parent call.
        # Whether the parent performs its own check is not visible here - confirm.
        U, S, V = super().submersion_inv(X)
        if check_in_manifold and not self.in_manifold_singular_values(S):
            raise InManifoldError(X, self)
        # Hardcoded epsilon... not a good practice
        if self.lam < 1e-6:
            # Dividing by a radius this small is avoided; only the shift is undone
            S = S - 1.0
        else:
            S = self.inv((S - 1.0) / self.lam)
        return U, S, V

    def in_manifold_singular_values(self, S, eps=1e-5):
        """Check that every entry of ``S`` is within the radius of ``1``.

        The radius used is ``self.lam``, or ``eps`` when ``self.lam <= eps``.
        The parent check is evaluated first and must also pass.
        """
        radius = eps if self.lam <= eps else self.lam
        return (
            super().in_manifold_singular_values(S, eps)
            and ((S - 1.0).abs() <= radius).all().item()
        )

    def sample(self, distribution="uniform", init_=None):
        r"""
        Returns a randomly sampled orthogonal matrix according to the specified
        ``distribution``. The options are:

            - ``"uniform"``: Samples a tensor distributed according to the Haar measure
              on :math:`\operatorname{SO}(n)`

            - ``"torus"``: Samples a block-diagonal skew-symmetric matrix.
              The blocks are of the form
              :math:`\begin{pmatrix} 0 & b \\ -b & 0\end{pmatrix}` where :math:`b` is
              distributed according to ``init_``. This matrix will be then projected
              onto :math:`\operatorname{SO}(n)` using ``self.triv``

        .. note::

            The ``"torus"`` initialization is particularly useful in recurrent kernels
            of RNNs

        The output of this method can be used to initialize a parametrized tensor
        that has been parametrized with this or any other manifold as::

            >>> layer = nn.Linear(20, 20)
            >>> M = AlmostOrthogonal(layer.weight.size(), lam=0.5)
            >>> geotorch.register_parametrization(layer, "weight", M)
            >>> layer.weight = M.sample()

        Args:
            distribution (string): Optional. One of ``["uniform", "torus"]``.
                    Default: ``"uniform"``
            init\_ (callable): Optional. To be used with the ``"torus"`` option.
                    A function that takes a tensor and fills it in place according
                    to some distribution. See
                    `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                    Default: :math:`\operatorname{Uniform}(-\pi, \pi)`
        """
        # Sample an orthogonal matrix as U and return it
        return self[0].sample(distribution=distribution, init_=init_)

    def extra_repr(self):
        """Return the string built by ``_extra_repr`` from this object's settings."""
        return _extra_repr(
            n=self.n,
            lam=self.lam,
            tensorial_size=self.tensorial_size,
            f=self.f,
            no_inv=self.inv is None,
        )