# Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved.
from dataclasses import dataclass
from typing import Union
import numpy as np
import numpy.typing as ntp
import torch
from adam.core.spatial_math import ArrayLike, ArrayLikeFactory, SpatialMath
@dataclass
class TorchLike(ArrayLike):
    """Class wrapping pyTorch types.

    The wrapped tensor is stored in ``self.array`` and is converted to
    ``torch.get_default_dtype()`` on construction. The arithmetic operators
    accept another ``TorchLike``, a ``torch.Tensor``, a Python scalar or any
    other array-like value (list, NumPy array, ...) and return a new
    ``TorchLike``.
    """

    # NOTE(review): no ``array`` field is declared in this class; presumably it
    # is inherited from ``ArrayLike`` -- confirm in adam.core.spatial_math.

    def __post_init__(self):
        """Converts array to a tensor of torch's default dtype."""
        default_dtype = torch.get_default_dtype()
        if not isinstance(self.array, torch.Tensor):
            # Plain values (floats, lists, NumPy arrays) have no ``.to`` method:
            # copy them into a fresh tensor so no caller-owned memory is aliased.
            self.array = torch.tensor(self.array, dtype=default_dtype)
        elif self.array.dtype != default_dtype:
            self.array = self.array.to(default_dtype)

    def _operand(self, value):
        """Return ``value`` in a form torch operators accept.

        Args:
            value: the other operand of a binary operator

        Returns:
            The wrapped tensor for a ``TorchLike``; tensors and Python scalars
            unchanged; anything else converted with ``torch.as_tensor`` to this
            array's dtype and device, so NumPy arrays or integer lists do not
            cause dtype/device mismatches. Tensors are never re-created, which
            keeps them attached to the autograd graph.
        """
        if isinstance(value, TorchLike):
            return value.array
        if isinstance(value, (torch.Tensor, int, float, complex)):
            return value
        return torch.as_tensor(value, dtype=self.array.dtype, device=self.array.device)

    def _squeezed(self, value):
        """Like ``_operand``, but squeezes tensor operands (scalars are left as-is).

        Used by ``+``/``-``, which squeeze ``self.array`` as well.
        """
        operand = self._operand(value)
        return operand.squeeze() if isinstance(operand, torch.Tensor) else operand

    def __setitem__(self, idx, value: Union["TorchLike", ntp.ArrayLike]) -> None:
        """Overrides set item operator.

        Args:
            idx: index/slice of the wrapped tensor to overwrite
            value: new content; a ``TorchLike`` is reshaped to the shape of the
                target slice, non-tensor values are converted to this array's
                dtype and device before assignment
        """
        if isinstance(value, TorchLike):
            self.array[idx] = value.array.reshape(self.array[idx].shape)
        elif isinstance(value, torch.Tensor):
            self.array[idx] = value
        else:
            # Convert explicitly instead of relying on torch's setitem to accept
            # lists or NumPy arrays.
            self.array[idx] = torch.as_tensor(
                value, dtype=self.array.dtype, device=self.array.device
            )

    def __getitem__(self, idx):
        """Overrides get item operator.

        Returns:
            TorchLike: the indexed part of the wrapped tensor
        """
        return TorchLike(self.array[idx])

    @property
    def shape(self):
        """
        Returns:
            torch.Size: shape of the wrapped tensor
        """
        return self.array.shape

    def reshape(self, *args):
        """
        Returns:
            torch.Tensor: the wrapped tensor reshaped to ``*args``
            (a raw tensor, not wrapped in ``TorchLike``)
        """
        return self.array.reshape(*args)

    @property
    def T(self) -> "TorchLike":
        """
        Returns:
            TorchLike: transpose of array (all dimensions reversed; a 0-D
            tensor is wrapped unchanged)
        """
        x = self.array
        if x.ndim == 0:
            # nothing to transpose
            return TorchLike(x)
        return TorchLike(x.permute(*reversed(range(x.ndim))))

    def __matmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides @ operator (``self @ other``)"""
        return TorchLike(self.array @ self._operand(other))

    def __rmatmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides @ operator (``other @ self``)"""
        return TorchLike(self._operand(other) @ self.array)

    def __mul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides * operator (``self * other``)"""
        return TorchLike(self.array * self._operand(other))

    def __rmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides * operator (``other * self``)"""
        return TorchLike(self._operand(other) * self.array)

    def __truediv__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides / operator (``self / other``)"""
        return TorchLike(self.array / self._operand(other))

    def __add__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides + operator; both operands are squeezed first."""
        return TorchLike(self.array.squeeze() + self._squeezed(other))

    def __radd__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides + operator (``other + self``), e.g. ``sum()`` starting at 0."""
        return TorchLike(self._squeezed(other) + self.array.squeeze())

    def __sub__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides - operator (``self - other``); both operands are squeezed first."""
        return TorchLike(self.array.squeeze() - self._squeezed(other))

    def __rsub__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides - operator (``other - self``); both operands are squeezed first."""
        return TorchLike(self._squeezed(other) - self.array.squeeze())

    def __neg__(self) -> "TorchLike":
        """Overrides unary - operator"""
        return TorchLike(-self.array)
class TorchLikeFactory(ArrayLikeFactory):
    """Factory creating ``TorchLike`` arrays."""

    @staticmethod
    def zeros(*x: int) -> "TorchLike":
        """
        Args:
            *x (int): dimensions of the matrix

        Returns:
            TorchLike: zero matrix of dimension *x
        """
        return TorchLike(torch.zeros(x))

    @staticmethod
    def eye(x: int) -> "TorchLike":
        """
        Args:
            x (int): dimension

        Returns:
            TorchLike: identity matrix of dimension x
        """
        return TorchLike(torch.eye(x))

    @staticmethod
    def array(x: ntp.ArrayLike) -> "TorchLike":
        """
        Args:
            x (ntp.ArrayLike): values to wrap: a ``TorchLike``, a
                ``torch.Tensor`` or anything ``torch.tensor`` accepts

        Returns:
            TorchLike: vector wrapping a copy of x
        """
        if isinstance(x, TorchLike):
            x = x.array
        if isinstance(x, torch.Tensor):
            # torch.tensor(tensor) returns a *detached* copy (and warns);
            # clone() also copies but keeps the result in the autograd graph.
            return TorchLike(x.clone())
        return TorchLike(torch.tensor(x))
def _to_torch_tensor(x) -> torch.Tensor:
    """Return ``x`` as a tensor without detaching existing tensors.

    Args:
        x: a ``TorchLike``, a ``torch.Tensor`` or any value accepted by
            ``torch.as_tensor`` (float, int, list, NumPy array, ...)

    Returns:
        torch.Tensor: the wrapped tensor for a ``TorchLike``, the tensor itself
        for a tensor (so gradients keep flowing), otherwise a new tensor of
        torch's default dtype.
    """
    if isinstance(x, TorchLike):
        return x.array
    if isinstance(x, torch.Tensor):
        return x
    return torch.as_tensor(x, dtype=torch.get_default_dtype())


def _concatenate(stack, parts) -> torch.Tensor:
    """Concatenate ``parts`` with ``stack`` (``torch.vstack`` or ``torch.hstack``).

    If any part is a ``TorchLike``, every part is converted with
    ``_to_torch_tensor`` and passed to ``stack``. Otherwise the parts are
    packed element-wise with ``torch.tensor`` (e.g. a tuple of floats becomes
    a 1-D tensor).

    Raises:
        ValueError: if ``parts`` is empty
    """
    if not parts:
        raise ValueError("cannot concatenate an empty sequence")
    if any(isinstance(part, TorchLike) for part in parts):
        return stack([_to_torch_tensor(part) for part in parts])
    return torch.tensor(parts)


class SpatialMath(SpatialMath):
    """PyTorch implementation of the spatial-math primitives.

    Arrays are created through ``TorchLikeFactory``; results are ``TorchLike``.
    """

    def __init__(self):
        super().__init__(TorchLikeFactory())

    @staticmethod
    def sin(x: ntp.ArrayLike) -> "TorchLike":
        """
        Args:
            x (ntp.ArrayLike): angle value (float, tensor, ``TorchLike``, ...)

        Returns:
            TorchLike: sin value of x
        """
        return TorchLike(torch.sin(_to_torch_tensor(x)))

    @staticmethod
    def cos(x: ntp.ArrayLike) -> "TorchLike":
        """
        Args:
            x (ntp.ArrayLike): angle value (float, tensor, ``TorchLike``, ...)

        Returns:
            TorchLike: cos value of x
        """
        return TorchLike(torch.cos(_to_torch_tensor(x)))

    @staticmethod
    def outer(x: ntp.ArrayLike, y: ntp.ArrayLike) -> "TorchLike":
        """
        Args:
            x (ntp.ArrayLike): vector
            y (ntp.ArrayLike): vector

        Returns:
            TorchLike: outer product of x and y
        """
        return TorchLike(torch.outer(_to_torch_tensor(x), _to_torch_tensor(y)))

    @staticmethod
    def skew(x: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """
        Args:
            x (Union[TorchLike, ntp.ArrayLike]): 3-element vector (any shape
                holding 3 elements, e.g. (3,) or (3, 1))

        Returns:
            TorchLike: skew matrix from x
        """
        v = _to_torch_tensor(x).reshape(-1)
        zero = torch.zeros((), dtype=v.dtype, device=v.device)
        # torch.stack (rather than torch.tensor on a nested list) keeps the
        # result connected to v's autograd graph.
        return TorchLike(
            torch.stack(
                [
                    torch.stack([zero, -v[2], v[1]]),
                    torch.stack([v[2], zero, -v[0]]),
                    torch.stack([-v[1], v[0], zero]),
                ]
            )
        )

    @staticmethod
    def vertcat(*x: ntp.ArrayLike) -> "TorchLike":
        """
        Returns:
            TorchLike: vertical concatenation of x (``torch.vstack`` when any
            argument is a ``TorchLike``)

        Raises:
            ValueError: if called without arguments
        """
        return TorchLike(_concatenate(torch.vstack, x))

    @staticmethod
    def horzcat(*x: ntp.ArrayLike) -> "TorchLike":
        """
        Returns:
            TorchLike: horizontal concatenation of x (``torch.hstack`` when
            any argument is a ``TorchLike``)

        Raises:
            ValueError: if called without arguments
        """
        return TorchLike(_concatenate(torch.hstack, x))