Source code for geotorch.grassmannian

import torch

from .stiefel import Stiefel


class Grassmannian(Stiefel):
    def __init__(self, size, triv="expm"):
        r"""
        Grassmannian manifold, realised as a projection from the orthogonal
        matrices :math:`\operatorname{St}(n,k)`. The canonical metric is used.

        All arguments are forwarded unchanged to the ``Stiefel`` constructor.

        Args:
            size (torch.Size): Size of the tensor to be parametrized
            triv (str or callable): Optional. A map sending skew-symmetric
                matrices surjectively onto the orthogonal matrices. Either
                one of ``["expm", "cayley"]`` or a custom callable.
                Default: ``"expm"``
        """
        super().__init__(size=size, triv=triv)

    def frame(self, X):
        """
        Replace the leading square block of ``X`` with zeros and delegate to
        ``Stiefel.frame``.

        With ``k = X.size(-1)``, a ``k x k`` zero block (same dtype and device
        as ``X``, same leading batch dimensions) is stacked on top of the rows
        ``X[..., k:, :]`` along the second-to-last dimension. The resulting
        tensor is what gets passed to the parent implementation.

        Args:
            X (torch.Tensor): Tensor with at least two dimensions.

        Returns:
            Whatever ``Stiefel.frame`` returns for the modified tensor.
        """
        n_cols = X.shape[-1]
        batch_dims = X.shape[:-2]
        # NOTE(review): zeroing the top k x k block presumably discards the
        # directions that do not change the spanned subspace; confirm against
        # Stiefel.frame.
        top_block = X.new_zeros(batch_dims + (n_cols, n_cols))
        lower_rows = X[..., n_cols:, :]
        stacked = torch.cat((top_block, lower_rows), dim=-2)
        return super().frame(stacked)