# Source code for geotorch.sphere

import torch
from torch import nn

from .exceptions import InManifoldError
from .utils import _extra_repr

def project(x):
    """Rescale ``x`` to unit Euclidean norm along its last dimension.

    All leading dimensions are treated as batch dimensions. A vector of
    norm zero is not special-cased, so it yields a division by zero.

    Args:
        x (torch.Tensor): Tensor whose last dimension holds the vectors.

    Returns:
        torch.Tensor: ``x`` divided by the 2-norm of its last dimension.
    """
    # keepdim=True keeps a trailing axis of size 1 so the division broadcasts
    length = x.norm(dim=-1, keepdim=True)
    return x / length

def uniform_init_sphere_(x, r=1.0):
    r"""Samples a point uniformly on the sphere into the tensor ``x``.

    If ``x`` has :math:`d > 1` dimensions, the first :math:`d-1` dimensions
    are treated as batch dimensions.

    Args:
        x (torch.Tensor): Tensor that is overwritten in place with the sample.
        r (float): Optional. Radius of the sphere. Default: ``1.``

    Returns:
        torch.Tensor: The tensor ``x`` that was passed in.
    """
    # Initialisation must not be tracked by autograd: without this guard the
    # in-place ``normal_`` raises on leaf tensors that require grad.
    with torch.no_grad():
        # A normalised isotropic Gaussian sample is uniform on the sphere
        x.normal_()
        x.data = r * project(x)
    return x

def _in_sphere(x, r, eps):
norm = x.norm(dim=-1)
rs = torch.full_like(norm, r)
return (torch.norm(norm - rs, p=float("inf")) < eps).all()

class sinc_class(torch.autograd.Function):
    r"""Differentiable :math:`\operatorname{sinc}(x) = \sin(x) / x`.

    Both the value and the derivative have a removable singularity at
    zero; each pass overwrites the entries near zero with the limit value
    instead of the NaN that the closed-form expression produces.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``sin(x) / x`` elementwise, with value ``1`` at zero."""
        ctx.save_for_backward(x)
        # Hardcoded for float, will do for now
        ret = torch.sin(x) / x
        # sin(0) / 0 evaluates to NaN; the limit at zero is 1
        ret[x.abs() < 1e-45] = 1.0
        return ret

    @staticmethod
    def backward(ctx, grad_output):
        """Return the incoming gradient scaled by the derivative of sinc."""
        (x,) = ctx.saved_tensors
        # d/dx [sin(x) / x] = cos(x) / x - sin(x) / x^2
        ret = torch.cos(x) / x - torch.sin(x) / (x * x)
        # The derivative tends to 0 at the origin
        ret[x.abs() < 1e-10] = 0.0
        return ret * grad_output


sinc = sinc_class.apply

class SphereEmbedded(nn.Module):
    r"""
    Sphere as the orthogonal projection from
    :math:`\mathbb{R}^n` to :math:`\mathbb{S}^{n-1}`, that is,
    :math:`x \mapsto \frac{x}{\lVert x \rVert}`.

    Args:
        size (torch.size): Size of the tensor to be parametrized
        radius (float): Optional.
            Radius of the sphere. A positive number. Default: ``1.``
    """

    def __init__(self, size, radius=1.0):
        super().__init__()
        self.n = size[-1]
        self.tensorial_size = size[:-1]
        self.radius = SphereEmbedded.parse_radius(radius)

    @staticmethod
    def parse_radius(radius):
        r"""
        Validates the radius and returns it unchanged.

        Raises:
            ValueError: If ``radius`` is not strictly positive.
        """
        # NOTE(review): the method name and the message text are not visible
        # in the damaged source -- confirm against callers.
        if radius <= 0.0:
            raise ValueError(
                "The radius has to be a positive real number. Got {}".format(radius)
            )
        return radius

    def forward(self, x):
        r"""
        Projects ``x`` onto the sphere of radius ``self.radius``.
        """
        return self.radius * project(x)

    def right_inverse(self, x, check_in_manifold=True):
        r"""
        Returns an unconstrained tensor that ``forward`` maps to ``x``.

        Args:
            x (torch.Tensor): A point on the sphere.
            check_in_manifold (bool): Optional. If ``True``, verify that
                ``x`` lies on the sphere first. Default: ``True``

        Raises:
            InManifoldError: If the check is enabled and fails.
        """
        if check_in_manifold and not self.in_manifold(x):
            raise InManifoldError(x, self)
        # NOTE(review): the rescaling is assumed. Any positive multiple of a
        # point of norm ``radius`` is mapped back to that point by ``forward``.
        return x / self.radius

    def in_manifold(self, x, eps=1e-5):
        r"""
        Checks that a vector is on the sphere.

        For tensors with more than one dimension the leading dimensions are
        treated as batch dimensions.

        Args:
            x (torch.Tensor): The vector to be checked.
            eps (float): Optional. Threshold at which the norm is considered
                to be equal to the radius. Default: ``1e-5``
        """
        return _in_sphere(x, self.radius, eps)

    def sample(self):
        r"""
        Returns a uniformly sampled vector on the sphere.
        """
        x = torch.empty(*(self.tensorial_size) + (self.n,))
        return uniform_init_sphere_(x, r=self.radius)

    def extra_repr(self):
        # NOTE(review): the keyword set accepted by ``_extra_repr`` is assumed
        # -- confirm against ``utils``.
        return _extra_repr(
            n=self.n, radius=self.radius, tensorial_size=self.tensorial_size
        )

class Sphere(nn.Module):
    r"""
    Sphere as a map from the tangent space onto the sphere using the
    exponential map.

    Args:
        size (torch.size): Size of the tensor to be parametrized
        radius (float): Optional.
            Radius of the sphere. A positive number. Default: ``1.``
    """

    def __init__(self, size, radius=1.0):
        super().__init__()
        self.n = size[-1]
        self.tensorial_size = size[:-1]
        self.radius = Sphere.parse_radius(radius)
        # Base point of the exponential map: a random point of unit norm
        self.register_buffer("base", uniform_init_sphere_(torch.empty(*size)))

    @staticmethod
    def parse_radius(radius):
        r"""
        Validates the radius and returns it unchanged.

        Raises:
            ValueError: If ``radius`` is not strictly positive.
        """
        # NOTE(review): the method name and the message text are not visible
        # in the damaged source -- confirm against callers.
        if radius <= 0.0:
            raise ValueError(
                "The radius has to be a positive real number. Got {}".format(radius)
            )
        return radius

    def frame(self, x, v):
        r"""
        Removes from ``v`` its component along ``x``.

        Args:
            x (torch.Tensor): Direction to remove; the result is orthogonal
                to it only when ``x`` has unit norm.
            v (torch.Tensor): Vector to be projected.
        """
        # Batched inner product <v, x>, kept with a trailing axis of size 1
        projection = (v.unsqueeze(-2) @ x.unsqueeze(-1)).squeeze(-1)
        v = v - projection * x
        return v

    def forward(self, v):
        r"""
        Maps ``v`` to the sphere of radius ``self.radius`` through the
        exponential map at ``self.base``.
        """
        x = self.base
        # Project v onto {<x,v> = 0}
        v = self.frame(x, v)
        vnorm = v.norm(dim=-1, keepdim=True)
        return self.radius * (torch.cos(vnorm) * x + sinc(vnorm) * v)

    def right_inverse(self, x, check_in_manifold=True):
        r"""
        Moves the base point to ``x`` and returns the tangent vector that
        ``forward`` maps to ``x``.

        Args:
            x (torch.Tensor): A point on the sphere.
            check_in_manifold (bool): Optional. If ``True``, verify that
                ``x`` lies on the sphere first. Default: ``True``

        Raises:
            InManifoldError: If the check is enabled and fails.
        """
        if check_in_manifold and not self.in_manifold(x):
            raise InManifoldError(x, self)
        with torch.no_grad():
            # ``forward`` multiplies by the radius, so the stored base point
            # has to have unit norm
            x = x / self.radius
            self.base.copy_(x)
        # The exponential map sends the zero tangent vector to the base point
        return torch.zeros_like(x)

    def in_manifold(self, x, eps=1e-5):
        r"""
        Checks that a vector is on the sphere.

        For tensors with more than one dimension the leading dimensions are
        treated as batch dimensions.

        Args:
            x (torch.Tensor): The vector to be checked.
            eps (float): Optional. Threshold at which the norm is considered
                to be equal to the radius. Default: ``1e-5``
        """
        return _in_sphere(x, self.radius, eps)

    def sample(self):
        r"""
        Returns a uniformly sampled vector on the sphere.
        """
        device = self.base.device
        dtype = self.base.dtype
        x = torch.empty(*(self.tensorial_size) + (self.n,), device=device, dtype=dtype)
        return uniform_init_sphere_(x, r=self.radius)