# Source code for geotorch.lowrank

import torch

from .product import ProductManifold
from .stiefel import Stiefel
from .reals import Rn
from .exceptions import VectorError, RankError, InManifoldError
from .utils import transpose, _extra_repr


class LowRank(ProductManifold):
    def __init__(self, size, rank, triv="expm"):
        r"""
        Variety of the matrices of rank :math:`r` or less.

        Args:
            size (torch.Size): Size of the tensor to be parametrized
            rank (int): Rank of the matrices.
                It has to be less than or equal to
                :math:`\min(\texttt{size}[-1], \texttt{size}[-2])`
            triv (str or callable): Optional.
                A map that maps skew-symmetric matrices onto the orthogonal matrices
                surjectively. This is used to optimize the :math:`U` and :math:`V` in the
                SVD. It can be one of ``["expm", "cayley"]`` or a custom
                callable. Default: ``"expm"``

        Raises:
            VectorError: If ``size`` has fewer than two dimensions.
            RankError: If ``rank`` is smaller than 1 or larger than
                :math:`\min(\texttt{size}[-1], \texttt{size}[-2])`.
        """
        n, k, tensorial_size, transposed = LowRank.parse_size(size)
        if rank > min(n, k) or rank < 1:
            raise RankError(n, k, rank)
        super().__init__(LowRank.manifolds(n, k, rank, tensorial_size, triv))
        # Internally matrices are handled "tall" as (*, n, k) with n >= k.
        # ``transposed`` records that the user-facing size was wide
        # (size[-2] < size[-1]); presumably the ``transpose`` decorator from
        # .utils reads it to flip inputs/outputs — confirm there.
        self.n = n
        self.k = k
        self.rank = rank
        self.tensorial_size = tensorial_size
        self.transposed = transposed

    @classmethod
    def parse_size(cls, size):
        """Split ``size`` into ``(n, k, tensorial_size, transposed)``.

        ``n``/``k`` are the larger/smaller of the last two dimensions,
        ``tensorial_size`` are the leading (batch) dimensions, and
        ``transposed`` is ``True`` when ``size[-2] < size[-1]``.

        Raises:
            VectorError: If ``size`` has fewer than two dimensions.
        """
        if len(size) < 2:
            raise VectorError(cls.__name__, size)
        transposed = size[-2] < size[-1]
        n = max(size[-2:])
        k = min(size[-2:])
        tensorial_size = size[:-2]
        return n, k, tensorial_size, transposed

    @staticmethod
    def manifolds(n, k, rank, tensorial_size, triv):
        """Build the factor manifolds of the SVD ``U diag(S) V^T``.

        Returns a Stiefel manifold of size ``(*, n, rank)`` for ``U``, a
        Euclidean space of size ``(*, rank)`` for ``S`` and a Stiefel manifold
        of size ``(*, k, rank)`` for ``V``.
        """
        size_u = tensorial_size + (n, rank)
        size_s = tensorial_size + (rank,)
        size_v = tensorial_size + (k, rank)
        return Stiefel(size_u, triv), Rn(size_s), Stiefel(size_v, triv)

    def frame(self, X):
        """Split an unconstrained ``(*, n, k)`` tensor into ``(U, S, V)`` inputs.

        The strictly lower-triangular part feeds ``U`` (shape ``(*, n, rank)``),
        the diagonal feeds ``S`` (shape ``(*, rank)``) and the strictly
        upper-triangular part, transposed, feeds ``V`` (shape ``(*, k, rank)``).
        The three parts do not overlap, so every entry of ``X`` is used at most
        once.
        """
        U = X.tril(-1)[..., : self.rank]
        S = X.diagonal(dim1=-2, dim2=-1)[..., : self.rank]
        V = X.triu(1).transpose(-2, -1)[..., : self.rank]
        return U, S, V

    def submersion(self, U, S, V):
        """Return ``U diag(S) V^T``.

        ``diag(S)`` is never materialised: ``S`` is broadcast over the columns
        of ``U`` instead.
        """
        return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1)

    @transpose
    def forward(self, X):
        """Map an unconstrained tensor onto a matrix of rank at most ``rank``."""
        X = self.frame(X)
        U, S, V = super().forward(X)
        return self.submersion(U, S, V)

    def frame_inv(self, X1, X2, X3):
        """Inverse of :meth:`frame`: pack ``(X1, X2, X3)`` into one ``(*, n, k)`` tensor.

        ``X1`` goes into the first ``rank`` columns, ``X2`` onto the leading
        ``rank`` diagonal entries and ``X3`` (transposed) into the first
        ``rank`` rows. The result is built without tracking gradients.
        """
        with torch.no_grad():
            # X1 is lower-triangular
            # X2 is a vector
            # X3 is lower-triangular
            size = self.tensorial_size + (self.n, self.k)
            ret = torch.zeros(size, dtype=X1.dtype, device=X1.device)
            ret[..., : self.rank] += X1
            ret[..., : self.rank, : self.rank] += torch.diag_embed(X2)
            # Writing through the transposed view fills the first ``rank`` rows.
            ret.transpose(-2, -1)[..., : self.rank] += X3
            return ret

    def submersion_inv(self, X, check_in_manifold=True):
        """Return the rank-truncated thin SVD factors ``(U, S, V)`` of ``X``.

        Raises:
            InManifoldError: If ``check_in_manifold`` is ``True`` and the
                singular values of ``X`` beyond ``rank`` are not (numerically)
                zero.
        """
        U, S, Vt = torch.linalg.svd(X, full_matrices=False)
        V = Vt.transpose(-2, -1)
        if check_in_manifold and not self.in_manifold_singular_values(S):
            raise InManifoldError(X, self)
        return U[..., : self.rank], S[..., : self.rank], V[..., : self.rank]

    @transpose
    def right_inverse(self, X, check_in_manifold=True):
        """Return an unconstrained tensor that :meth:`forward` maps to ``X``.

        Raises:
            InManifoldError: If ``check_in_manifold`` is ``True`` and ``X`` has
                rank larger than ``rank`` (see :meth:`submersion_inv`).
        """
        USV = self.submersion_inv(X, check_in_manifold)
        # Membership was already checked (or deliberately skipped) above.
        X1, X2, X3 = super().right_inverse(USV, check_in_manifold=False)
        return self.frame_inv(X1, X2, X3)

    def in_manifold_singular_values(self, S, eps=1e-5):
        r"""
        Checks that an ordered vector of singular values is in the manifold.

        For tensors with more than 1 dimension the first dimensions are
        treated as batch dimensions.

        Args:
            S (torch.Tensor): Vector of singular values
            eps (float): Optional. Threshold at which the singular values are
                considered to be zero
                Default: ``1e-5``

        Returns:
            ``True`` if there are at most ``rank`` singular values. Otherwise a
            0-dim boolean tensor that is ``True`` when every singular value past
            position ``rank`` is below ``eps`` in absolute value.
        """
        if S.size(-1) <= self.rank:
            return True
        # We compute the \infty-norm of the remaining dimension
        D = S[..., self.rank :]
        infty_norm_err = D.abs().max(dim=-1).values
        return (infty_norm_err < eps).all()

    def in_manifold(self, X, eps=1e-5):
        r"""
        Checks that a given matrix is in the manifold.

        Args:
            X (torch.Tensor): The input matrix or batch of matrices. Either
                orientation, ``(*, n, k)`` or ``(*, k, n)``, is accepted.
            eps (float): Optional. Threshold at which the singular values are
                considered to be zero
                Default: ``1e-5``

        Returns:
            ``False`` if ``X`` has fewer than two dimensions or the wrong size.
            Otherwise the result of :meth:`in_manifold_singular_values` on the
            singular values of ``X``.
        """
        # A 0-D or 1-D tensor cannot be a matrix; without this guard
        # ``X.size(-2)`` raises IndexError instead of answering the predicate.
        if X.dim() < 2:
            return False
        if X.size(-1) > X.size(-2):
            X = X.transpose(-2, -1)
        if X.size() != self.tensorial_size + (self.n, self.k):
            return False
        S = torch.linalg.svdvals(X)
        return self.in_manifold_singular_values(S, eps)

    def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False):
        r"""
        Returns a randomly sampled matrix on the manifold by sampling a matrix
        according to ``init_`` and projecting it onto the manifold.

        The output of this method can be used to initialize a parametrized tensor
        that has been parametrized with this or any other manifold as::

            >>> layer = nn.Linear(20, 20)
            >>> M = LowRank(layer.weight.size(), rank=6)
            >>> geotorch.register_parametrization(layer, "weight", M)
            >>> layer.weight = M.sample()

        Args:
            init\_ (callable): Optional. A function that takes a tensor and fills it
                in place according to some distribution. See
                `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                Default: ``torch.nn.init.xavier_normal_``
            factorized (bool): Optional. If ``True``, return the truncated SVD
                factors ``(U, S, V)`` of the projected matrix in the internal
                ``(*, n, k)`` orientation: ``U`` of shape ``(*, n, rank)``,
                ``S`` of shape ``(*, rank)`` and ``V`` of shape
                ``(*, k, rank)``. If ``False``, return the projected matrix
                itself, with the size given at construction.
                Default: ``False``
        """
        with torch.no_grad():
            # Match the device/dtype of the first factor manifold's buffers.
            device = self[0].base.device
            dtype = self[0].base.dtype
            X = torch.empty(
                *(self.tensorial_size + (self.n, self.k)), device=device, dtype=dtype
            )
            init_(X)
            # Project onto the variety by keeping the leading ``rank``
            # singular triplets.
            U, S, Vt = torch.linalg.svd(X, full_matrices=False)
            U, S, Vt = U[..., : self.rank], S[..., : self.rank], Vt[..., : self.rank, :]
            if factorized:
                return U, S, Vt.transpose(-2, -1)
            else:
                X = (U * S.unsqueeze(-2)) @ Vt
                if self.transposed:
                    X = X.transpose(-2, -1)
                return X

    def extra_repr(self):
        """Describe the manifold's configuration for ``repr``."""
        return _extra_repr(
            n=self.n,
            k=self.k,
            rank=self.rank,
            tensorial_size=self.tensorial_size,
            transposed=self.transposed,
        )