# Source code for adam.pytorch.torch_like

# Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved.

from dataclasses import dataclass
from typing import Union

import numpy as np
import numpy.typing as ntp
import torch

from adam.core.spatial_math import ArrayLike, ArrayLikeFactory, SpatialMath


@dataclass
class TorchLike(ArrayLike):
    """Class wrapping pyTorch types.

    The wrapped tensor is always cast to ``torch.get_default_dtype()`` on
    construction. Arithmetic dunders accept another ``TorchLike``, a
    ``torch.Tensor``, or generic array-like / scalar values.
    """

    array: torch.Tensor

    def __post_init__(self):
        """Converts array to the desired type (torch's global default dtype)."""
        if self.array.dtype != torch.get_default_dtype():
            self.array = self.array.to(torch.get_default_dtype())

    def _as_operand(self, value) -> torch.Tensor:
        """Return ``value`` as a tensor that can be combined with ``self.array``.

        ``TorchLike`` values are unwrapped. ``torch.Tensor`` values are passed
        through untouched, so their autograd history is kept. Anything else
        (lists, ndarrays, scalars) is converted with ``torch.as_tensor`` using
        the dtype of ``self.array``; ``torch.matmul`` does not promote dtypes,
        so a mismatch would otherwise raise.
        """
        if isinstance(value, TorchLike):
            return value.array
        if isinstance(value, torch.Tensor):
            return value
        return torch.as_tensor(value, dtype=self.array.dtype)

    @staticmethod
    def _squeezed(value):
        """Squeeze ``value`` when it supports it; plain Python scalars pass through."""
        squeeze = getattr(value, "squeeze", None)
        return squeeze() if squeeze is not None else value

    def __setitem__(self, idx, value: Union["TorchLike", ntp.ArrayLike]) -> None:
        """Overrides set item operator"""
        if type(self) is type(value):
            # Reshape the source so it fits the destination slice exactly.
            self.array[idx] = value.array.reshape(self.array[idx].shape)
        else:
            self.array[idx] = torch.tensor(value) if isinstance(value, float) else value

    def __getitem__(self, idx):
        """Overrides get item operator"""
        return TorchLike(self.array[idx])

    @property
    def shape(self):
        """Shape of the wrapped tensor."""
        return self.array.shape

    def reshape(self, *args):
        """Reshape the wrapped tensor.

        Note: returns a plain ``torch.Tensor``, not a ``TorchLike``.
        """
        return self.array.reshape(*args)

    @property
    def T(self) -> "TorchLike":
        """
        Returns:
            TorchLike: transpose of array (all dimensions reversed)
        """
        # A 0-D tensor has nothing to transpose.
        if self.array.ndim == 0:
            return TorchLike(self.array)
        x = self.array
        return TorchLike(x.permute(*range(x.ndim - 1, -1, -1)))

    def __matmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides @ operator"""
        if type(self) is type(other):
            return TorchLike(self.array @ other.array)
        return TorchLike(self.array @ self._as_operand(other))

    def __rmatmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides @ operator (``other @ self``)"""
        if type(self) is type(other):
            return TorchLike(other.array @ self.array)
        # Tensors are not re-wrapped with torch.tensor(), which would copy,
        # warn and detach them from the autograd graph.
        return TorchLike(self._as_operand(other) @ self.array)

    def __mul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides * operator"""
        if type(self) is type(other):
            return TorchLike(self.array * other.array)
        return TorchLike(self.array * other)

    def __rmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides * operator (``other * self``)"""
        if type(self) is type(other):
            return TorchLike(other.array * self.array)
        return TorchLike(other * self.array)

    def __truediv__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides / operator"""
        if type(self) is type(other):
            return TorchLike(self.array / other.array)
        return TorchLike(self.array / other)

    def __add__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides + operator; both operands are squeezed first."""
        if type(self) is not type(other):
            return TorchLike(self.array.squeeze() + self._squeezed(other))
        return TorchLike(self.array.squeeze() + other.array.squeeze())

    def __radd__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides + operator (``other + self``); also makes ``sum()`` work."""
        if type(self) is not type(other):
            return TorchLike(self.array.squeeze() + self._squeezed(other))
        return TorchLike(self.array.squeeze() + other.array.squeeze())

    def __sub__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides - operator; both operands are squeezed first."""
        if type(self) is type(other):
            return TorchLike(self.array.squeeze() - other.array.squeeze())
        return TorchLike(self.array.squeeze() - self._squeezed(other))

    def __rsub__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """Overrides - operator (``other - self``); both operands are squeezed first."""
        if type(self) is type(other):
            return TorchLike(other.array.squeeze() - self.array.squeeze())
        return TorchLike(self._squeezed(other) - self.array.squeeze())

    def __neg__(self) -> "TorchLike":
        """Overrides unary - operator"""
        return TorchLike(-self.array)
class TorchLikeFactory(ArrayLikeFactory):
    """Factory producing ``TorchLike`` objects for the generic spatial-math code."""

    @staticmethod
    def zeros(*x: int) -> "TorchLike":
        """
        Returns:
            TorchLike: zero matrix of dimension *x
        """
        shape = tuple(x)
        zero_tensor = torch.zeros(shape)
        return TorchLike(zero_tensor)

    @staticmethod
    def eye(x: int) -> "TorchLike":
        """
        Args:
            x (int): dimension

        Returns:
            TorchLike: identity matrix of dimension x
        """
        identity = torch.eye(x)
        return TorchLike(identity)

    @staticmethod
    def array(x: ntp.ArrayLike) -> "TorchLike":
        """
        Returns:
            TorchLike: vector wrapping x (a fresh tensor copy of the input)
        """
        wrapped = torch.tensor(x)
        return TorchLike(wrapped)
class SpatialMath(SpatialMath):
    """PyTorch implementation of the spatial-math primitives.

    The class intentionally reuses the name of the generic ``SpatialMath`` it
    extends (imported from ``adam.core.spatial_math``). The base is bound when
    this class statement is evaluated, so the shadowing is safe.
    """

    def __init__(self):
        """Initialise the generic base with a ``TorchLikeFactory``."""
        super().__init__(TorchLikeFactory())

    @staticmethod
    def _to_tensor(x) -> torch.Tensor:
        """Unwrap a ``TorchLike`` or convert array-like input to a tensor.

        Tensors are returned as-is, so gradients keep flowing through them.
        """
        if isinstance(x, TorchLike):
            return x.array
        return torch.as_tensor(x)

    @staticmethod
    def sin(x: ntp.ArrayLike) -> "TorchLike":
        """
        Args:
            x (ntp.ArrayLike): angle value (scalar, tensor, array or TorchLike)

        Returns:
            TorchLike: sin value of x
        """
        return TorchLike(torch.sin(SpatialMath._to_tensor(x)))

    @staticmethod
    def cos(x: ntp.ArrayLike) -> "TorchLike":
        """
        Args:
            x (ntp.ArrayLike): angle value (scalar, tensor, array or TorchLike)

        Returns:
            TorchLike: cos value of x
        """
        return TorchLike(torch.cos(SpatialMath._to_tensor(x)))

    @staticmethod
    def outer(x: ntp.ArrayLike, y: ntp.ArrayLike) -> "TorchLike":
        """
        Args:
            x (ntp.ArrayLike): vector
            y (ntp.ArrayLike): vector

        Returns:
            TorchLike: outer product of x and y
        """
        # torch.tensor() would copy and detach tensor inputs, and cannot
        # handle TorchLike inputs at all.
        return TorchLike(
            torch.outer(SpatialMath._to_tensor(x), SpatialMath._to_tensor(y))
        )

    @staticmethod
    def skew(x: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
        """
        Args:
            x (Union[TorchLike, ntp.ArrayLike]): 3-element vector (any shape
                that flattens to 3 elements)

        Returns:
            TorchLike: skew matrix from x, i.e. [[0, -x2, x1], [x2, 0, -x0], [-x1, x0, 0]]
        """
        v = SpatialMath._to_tensor(x).reshape(-1)
        # Build with torch.stack rather than torch.tensor(...) so the result
        # stays attached to the autograd graph of v.
        zero = torch.zeros_like(v[0])
        return TorchLike(
            torch.stack(
                [
                    torch.stack([zero, -v[2], v[1]]),
                    torch.stack([v[2], zero, -v[0]]),
                    torch.stack([-v[1], v[0], zero]),
                ]
            )
        )

    @staticmethod
    def vertcat(*x: ntp.ArrayLike) -> "TorchLike":
        """
        Returns:
            TorchLike: vertical concatenation of x. When the first argument is
            a TorchLike, all arguments are stacked with ``torch.vstack``;
            otherwise the arguments are packed into one tensor with
            ``torch.tensor``.
        """
        if isinstance(x[0], TorchLike):
            v = torch.vstack([SpatialMath._to_tensor(item) for item in x])
        else:
            v = torch.tensor(x)
        return TorchLike(v)

    @staticmethod
    def horzcat(*x: ntp.ArrayLike) -> "TorchLike":
        """
        Returns:
            TorchLike: horizontal concatenation of x. When the first argument
            is a TorchLike, all arguments are stacked with ``torch.hstack``;
            otherwise the arguments are packed into one tensor with
            ``torch.tensor``.
        """
        if isinstance(x[0], TorchLike):
            v = torch.hstack([SpatialMath._to_tensor(item) for item in x])
        else:
            v = torch.tensor(x)
        return TorchLike(v)