import torch
from .utils import transpose, _extra_repr
from .so import SO, _has_orthonormal_columns
from .exceptions import VectorError, InManifoldError
class Stiefel(SO):
    def __init__(self, size, triv="expm"):
        r"""
        Manifold of rectangular orthogonal matrices parametrized as a projection
        onto the first :math:`k` columns from the space of square orthogonal matrices
        :math:`\operatorname{SO}(n)`. The metric considered is the canonical.

        Args:
            size (torch.size): Size of the tensor to be parametrized
            triv (str or callable): Optional.
                A map that maps skew-symmetric matrices onto the orthogonal matrices
                surjectively. It can be one of ``["expm", "cayley"]`` or a custom
                callable. Default: ``"expm"``
        """
        super().__init__(size=Stiefel.size_so(size), triv=triv, lower=True)
        # k: number of orthonormal columns. self.n (set by SO.__init__) is the
        # ambient dimension max(size[-2], size[-1]).
        self.k = min(size[-1], size[-2])
        # Wide inputs (more columns than rows) are handled by transposing, so
        # internally the parametrization always works with tall n x k matrices.
        # The @transpose decorator on forward/right_inverse uses this flag.
        self.transposed = size[-2] < size[-1]

    @classmethod
    def size_so(cls, size):
        r"""
        Returns the size of the square :math:`\operatorname{SO}(n)` tensor from
        which a rectangular tensor of size ``size`` is obtained by projection,
        i.e. ``size`` with its last two dimensions replaced by their maximum.

        Raises:
            VectorError: If ``size`` has fewer than two dimensions.
        """
        if len(size) < 2:
            raise VectorError(cls.__name__, size)
        size_so = list(size)
        size_so[-1] = size_so[-2] = max(size[-1], size[-2])
        return tuple(size_so)

    def frame(self, X):
        r"""
        Completes the tall matrix ``X`` of size ``(*, n, k)`` with zero columns
        into a square matrix of size ``(*, n, n)``, so that it can be mapped
        through the :math:`\operatorname{SO}(n)` parametrization.
        """
        n, k = X.size(-2), X.size(-1)
        size_z = X.size()[:-2] + (n, n - k)
        return torch.cat([X, X.new_zeros(*size_z)], dim=-1)

    @transpose
    def forward(self, X):
        # Embed X into SO(n), trivialize there, and keep the first k columns.
        X = self.frame(X)
        X = super().forward(X)
        return X[..., : self.k]

    @transpose
    def right_inverse(self, X, check_in_manifold=True):
        if check_in_manifold and not self.in_manifold(X):
            raise InManifoldError(X, self)
        if self.n != self.k:
            # N will be a completion of X to an orthogonal basis of R^n
            N = X.new_empty(*(self.tensorial_size + (self.n, self.n - self.k)))
            with torch.no_grad():
                N.normal_()
                # We assume for now that X is orthogonal.
                # This will be checked in super().right_inverse()
                # Project N onto the orthogonal complement to X
                # We iterate this twice for this algorithm to be numerically stable
                # This is standard, as done in some stochastic SVD algorithms
                for _ in range(2):
                    N = N - X @ (X.transpose(-2, -1) @ N)
                # And make it an orthonormal base of the image
                N = torch.linalg.qr(N).Q
            X = torch.cat([X, N], dim=-1)
        return super().right_inverse(X, check_in_manifold=False)[..., : self.k]

    def in_manifold(self, X, eps=1e-4):
        r"""
        Checks that a matrix is in the manifold.

        For tensors with more than 2 dimensions the first dimensions are
        treated as batch dimensions.

        Args:
            X (torch.Tensor): The matrix to be checked
            eps (float): Optional. Tolerance to numerical errors.
                Default: ``1e-4``
        """
        # Normalize to the tall representation before comparing sizes.
        if X.size(-1) > X.size(-2):
            X = X.transpose(-2, -1)
        if X.size() != self.tensorial_size + (self.n, self.k):
            return False
        return _has_orthonormal_columns(X, eps)

    def sample(self, distribution="uniform", init_=None):
        r"""
        Returns a randomly sampled orthogonal matrix according to the specified
        ``distribution``. The options are:

        - ``"uniform"``: Samples a tensor distributed according to the Haar measure
          on :math:`\operatorname{SO}(n)`

        - ``"torus"``: Samples a block-diagonal skew-symmetric matrix.
          The blocks are of the form
          :math:`\begin{pmatrix} 0 & b \\ -b & 0\end{pmatrix}` where :math:`b` is
          distributed according to ``init_``. This matrix will be then projected
          onto :math:`\operatorname{SO}(n)` using ``self.triv``

        .. note

            The ``"torus"`` initialization is particularly useful in recurrent kernels
            of RNNs

        Args:
            distribution (string): Optional. One of ``["uniform", "torus"]``.
                Default: ``"uniform"``
            init\_ (callable): Optional. To be used with the ``"torus"`` option.
                A function that takes a tensor and fills it in place according
                to some distribution. See
                `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                Default: :math:`\operatorname{Uniform}(-\pi, \pi)`
        """
        # Sample on SO(n) and project onto the first k columns (or rows, when
        # the parametrized tensor is wide).
        X = super().sample(distribution, init_)
        if not self.transposed:
            return X[..., : self.k]
        else:
            return X[..., : self.k, :]

    def extra_repr(self):
        # Summary string shown in the module's repr.
        return _extra_repr(
            n=self.n,
            k=self.k,
            tensorial_size=self.tensorial_size,
            triv=self.triv,
            transposed=self.transposed,
        )