Skip to content

Commit

Permalink
feat: adding TRTensor class and random_tr function for the Tensorizat…
Browse files Browse the repository at this point in the history
…ion Function API (ivy-llc#22197)

Co-authored-by: Anwaar Khalid <[email protected]>
  • Loading branch information
2 people authored and iababio committed Sep 27, 2023
1 parent 949af35 commit 3f3fb2f
Show file tree
Hide file tree
Showing 12 changed files with 456 additions and 5 deletions.
13 changes: 11 additions & 2 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ class CPTensor:
pass


class TRTensor:
pass


class Parafac2Tensor:
pass

Expand Down Expand Up @@ -766,8 +770,13 @@ class Node(str):
add_ivy_container_instance_methods,
)
from .data_classes.nested_array import NestedArray
from .data_classes.factorized_tensor import TuckerTensor, CPTensor, Parafac2Tensor
from .data_classes.factorized_tensor import TuckerTensor, CPTensor, TTTensor
from .data_classes.factorized_tensor import (
TuckerTensor,
CPTensor,
TRTensor,
TTTensor,
Parafac2Tensor,
)
from ivy.utils.backend import (
current_backend,
compiled_backends,
Expand Down
1 change: 1 addition & 0 deletions ivy/data_classes/factorized_tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .tucker_tensor import TuckerTensor
from .cp_tensor import CPTensor
from .tr_tensor import TRTensor
from .parafac2_tensor import Parafac2Tensor
from .tt_tensor import TTTensor
197 changes: 197 additions & 0 deletions ivy/data_classes/factorized_tensor/tr_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# local

from .base import FactorizedTensor
import ivy

# global
import warnings


class TRTensor(FactorizedTensor):
def __init__(self, factors):
super().__init__()
shape, rank = TRTensor.validate_tr_tensor(factors)
self.shape = tuple(shape)
self.rank = tuple(rank)
self.factors = factors

# Built-ins #
# ----------#
def __getitem__(self, index):
return self.factors[index]

def __setitem__(self, index, value):
self.factors[index] = value

def __iter__(self):
for index in range(len(self)):
yield self[index]

def __len__(self):
return len(self.factors)

def __repr__(self):
message = (
f"factors list : rank-{self.rank} tensor ring tensor of shape {self.shape}"
)
return message

# Public Methods #
# ---------------#

def to_tensor(self):
return TRTensor.tr_to_tensor(self.factors)

def to_unfolded(self, mode):
return TRTensor.tr_to_unfolded(self.factors, mode)

def to_vec(self):
return TRTensor.tr_to_vec(self.factors)

# Properties #
# ---------------#
@property
def n_param(self):
factors = self.factors
total_params = sum(int(ivy.prod(tensor.shape)) for tensor in factors)
return total_params

# Class Methods #
# ---------------#
@staticmethod
def validate_tr_tensor(factors):
n_factors = len(factors)

if n_factors < 2:
raise ValueError(
"A Tensor Ring tensor should be composed of at least two factors."
f"However, {n_factors} factor was given."
)

rank = []
shape = []
next_rank = None
for index, factor in enumerate(factors):
current_rank, current_shape, next_rank = ivy.shape(factor)

# Check that factors are third order tensors
if not len(factor.shape) == 3:
raise ValueError(
"TR expresses a tensor as third order factors (tr-cores).\n"
f"However, ivy.ndim(factors[{index}]) = {len(factor.shape)}"
)

# Consecutive factors should have matching ranks
if ivy.shape(factors[index - 1])[2] != current_rank:
raise ValueError(
"Consecutive factors should have matching ranks\n -- e.g."
" ivy.shape(factors[0])[2]) == ivy.shape(factors[1])[0])\nHowever,"
f" ivy.shape(factor[{index-1}])[2] =="
f" {ivy.shape(factors[index-1])[2]} but"
f" ivy.shape(factor[{index}])[0] == {current_rank}"
)

shape.append(current_shape)
rank.append(current_rank)

# Add last rank (boundary condition)
rank.append(next_rank)

return tuple(shape), tuple(rank)

@staticmethod
def tr_to_tensor(factors):
full_shape = [f.shape[1] for f in factors]
full_tensor = ivy.reshape(factors[0], (-1, factors[0].shape[2]))

for factor in factors[1:-1]:
rank_prev, _, rank_next = factor.shape
factor = ivy.reshape(factor, (rank_prev, -1))
full_tensor = ivy.dot(full_tensor, factor)
full_tensor = ivy.reshape(full_tensor, (-1, rank_next))

full_tensor = ivy.reshape(
full_tensor, (factors[-1].shape[2], -1, factors[-1].shape[0])
)
full_tensor = ivy.moveaxis(full_tensor, 0, -1)
full_tensor = ivy.reshape(
full_tensor, (-1, factors[-1].shape[0] * factors[-1].shape[2])
)
factor = ivy.moveaxis(factors[-1], -1, 1)
factor = ivy.reshape(factor, (-1, full_shape[-1]))
full_tensor = ivy.dot(full_tensor, factor)
return ivy.reshape(full_tensor, full_shape)

@staticmethod
def tr_to_unfolded(factors, mode):
return ivy.unfold(TRTensor.tr_to_tensor(factors), mode)

@staticmethod
def tr_to_vec(factors):
return ivy.reshape(
TRTensor.tr_to_tensor(factors),
(-1,),
)

@staticmethod
def validate_tr_rank(tensor_shape, rank="same", rounding="round"):
if rounding == "ceil":
rounding_fun = ivy.ceil
elif rounding == "floor":
rounding_fun = ivy.floor
elif rounding == "round":
rounding_fun = ivy.round
else:
raise ValueError(
f"Rounding should be round, floor or ceil, but got {rounding}"
)

if rank == "same":
rank = float(1)

n_dim = len(tensor_shape)
if n_dim == 2:
warnings.warn(
"Determining the TR-rank for the trivial case of a matrix"
f" (order 2 tensor) of shape {tensor_shape}, not a higher-order tensor."
)

if isinstance(rank, float):
# Choose the *same* rank for each mode
n_param_tensor = ivy.prod(tensor_shape) * rank

# R_k I_k R_{k+1} = R^2 I_k
solution = int(
rounding_fun(ivy.sqrt(n_param_tensor / ivy.sum(tensor_shape)))
)
rank = (solution,) * (n_dim + 1)

else:
# Check user input for potential errors
n_dim = len(tensor_shape)
if isinstance(rank, int):
rank = (rank,) * (n_dim + 1)
elif n_dim + 1 != len(rank):
message = (
"Provided incorrect number of ranks. Should verify len(rank) =="
f" len(tensor.shape)+1, but len(rank) = {len(rank)} while"
f" len(tensor.shape)+1 = {n_dim + 1}"
)
raise ValueError(message)

# Check first and last rank
if rank[0] != rank[-1]:
message = (
f"Provided rank[0] == {rank[0]} and rank[-1] == {rank[-1]}"
" but boundary conditions dictate rank[0] == rank[-1]"
)
raise ValueError(message)

return list(rank)

@staticmethod
def tr_n_param(tensor_shape, rank):
factor_params = []
for i, s in enumerate(tensor_shape):
factor_params.append(rank[i] * s * rank[i + 1])
return ivy.sum(factor_params)
5 changes: 4 additions & 1 deletion ivy/functional/backends/tensorflow/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

# local
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
from ivy.functional.ivy.random import (
_check_bounds_and_get_shape,
_randint_check_dtype_and_bound,
Expand All @@ -26,6 +26,9 @@
# ------#


@with_supported_dtypes(
{"2.13.0 and below": ("float", "int32", "int64")}, backend_version
)
def random_uniform(
*,
low: Union[float, tf.Tensor, tf.Variable] = 0.0,
Expand Down
63 changes: 61 additions & 2 deletions ivy/functional/ivy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,64 @@ def random_cp(
return ivy.CPTensor((weights, factors))


@handle_exceptions
@handle_nestable
@infer_dtype
def random_tr(
shape: Sequence[int],
rank: Sequence[int],
/,
*,
dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
full: Optional[bool] = False,
seed: Optional[int] = None,
) -> Union[ivy.TRTensor, ivy.Array]:
"""
Generate a random TR tensor.
Parameters
----------
shape : tuple
shape of the tensor to generate
rank : Sequence[int]
rank of the TR decomposition
must verify rank[0] == rank[-1] (boundary conditions)
and len(rank) == len(shape)+1
full : bool, optional, default is False
if True, a full tensor is returned
otherwise, the decomposed tensor is returned
seed :
seed for generating random numbers
context : dict
context in which to create the tensor
Returns
-------
ivy.TRTensor or ivy.Array if full is True
"""
rank = ivy.TRTensor.validate_tr_rank(shape, rank)
# Make sure it's not a tuple but a list
rank = list(rank)
_check_first_and_last_rank_elements_are_equal(rank)
factors = [
ivy.random_uniform(shape=(rank[i], s, rank[i + 1]), dtype=dtype, seed=seed)
for i, s in enumerate(shape)
]
if full:
return ivy.TRTensor.tr_to_tensor(factors)
else:
return ivy.TRTensor(factors)


def _check_first_and_last_rank_elements_are_equal(rank):
if rank[0] != rank[-1]:
message = (
f"Provided rank[0] == {rank[0]} and rank[-1] == {rank[-1]} "
"but boundary conditions dictate rank[0] == rank[-1]."
)
raise ValueError(message)


@handle_exceptions
@handle_nestable
@infer_dtype
Expand Down Expand Up @@ -993,8 +1051,9 @@ def trilu(
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Return the upper or lower triangular part of a matrix (or a stack of matrices) ``x``
.. note::
Return the upper or lower triangular part of a matrix
(or a stack of matrices) ``x``.
note::
The upper triangular part of the matrix is defined as the elements
on and above the specified diagonal ``k``. The lower triangular part
of the matrix is defined as the elements on and below the specified
Expand Down
Loading

0 comments on commit 3f3fb2f

Please sign in to comment.