Merge branch 'main' of github.com:kamlishgoswami/ivy into kamlish
kamlishgoswami committed Oct 24, 2023
2 parents 317f6bb + 03602a2 commit a4e9425
Showing 16 changed files with 837 additions and 16 deletions.
2 changes: 0 additions & 2 deletions ivy/functional/backends/numpy/__init__.py
@@ -34,10 +34,8 @@ def rep_method(self, ufunc, method, *inputs, **kwargs):
"bitwise_and": "bitwise_and",
"matmul": "matmul",
"power": "pow",
"divide": "divide",
"subtract": "subtract",
"add": "add",
"not_equal": "not_equal",
}
if ufunc.__name__ in methods.keys():
return eval("ivy." + methods[ufunc.__name__] + "(*inputs, **kwargs)")
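
For context, the mapping above is what lets NumPy ufuncs applied to ivy-backed arrays be redirected to the corresponding ivy functions. A minimal sketch of that dispatch idea, using getattr rather than eval and an abbreviated mapping (illustration only, not the backend's actual helper):

import ivy

# abbreviated copy of the ufunc-name -> ivy-function-name table shown above
_UFUNC_TO_IVY = {"matmul": "matmul", "power": "pow", "add": "add"}

def dispatch_ufunc(ufunc_name, *inputs, **kwargs):
    # resolve the mapped ivy function by name and call it with the original arguments
    if ufunc_name in _UFUNC_TO_IVY:
        return getattr(ivy, _UFUNC_TO_IVY[ufunc_name])(*inputs, **kwargs)
    raise NotImplementedError(f"no ivy mapping for ufunc {ufunc_name!r}")
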
5 changes: 5 additions & 0 deletions ivy/functional/frontends/jax/numpy/mathematical_functions.py
@@ -322,6 +322,11 @@ def inner(a, b):
return ivy.inner(a, b)


@to_ivy_arrays_and_back
def interp(x, xp, fp, left=None, right=None, period=None):
return ivy.interp(x, xp, fp, left=left, right=right, period=period)


@to_ivy_arrays_and_back
def kron(a, b):
a, b = promote_types_of_jax_inputs(a, b)
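
The new frontend simply forwards to ivy.interp and is meant to match jax.numpy.interp. For reference, the native behaviour it mirrors (piecewise-linear interpolation of query points against sample points), assuming JAX is installed:

import jax.numpy as jnp

x = jnp.array([0.5, 1.5, 2.5])           # query points
xp = jnp.array([0.0, 1.0, 2.0, 3.0])     # known x coordinates (must be increasing)
fp = jnp.array([0.0, 10.0, 20.0, 30.0])  # known y coordinates
print(jnp.interp(x, xp, fp))             # [ 5. 15. 25.]
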
@@ -74,6 +74,14 @@ def ifft(a, n=None, axis=-1, norm=None):
return ivy.ifft(a, axis, norm=norm, n=n)


@with_unsupported_dtypes({"1.24.3 and below": ("float16",)}, "numpy")
@to_ivy_arrays_and_back
def ifft2(a, s=None, axes=(-2, -1), norm=None):
a = ivy.asarray(a, dtype=ivy.complex128)
a = ivy.ifftn(a, s=s, axes=axes, norm=norm)
return a


@with_unsupported_dtypes({"1.24.3 and below": ("float16",)}, "numpy")
@to_ivy_arrays_and_back
def ifftn(a, s=None, axes=None, norm=None):
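
The added ifft2 upcasts the input to complex128 and delegates to ivy.ifftn over the last two axes, mirroring numpy.fft.ifft2. A quick reference for the native behaviour it targets, assuming NumPy is installed:

import numpy as np

a = np.arange(16.0).reshape(4, 4)
res = np.fft.ifft2(a)                    # inverse 2-D FFT over the last two axes
print(res.shape, res.dtype)              # (4, 4) complex128
print(np.allclose(np.fft.fft2(res), a))  # round-trips back to the input: True
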
3 changes: 3 additions & 0 deletions ivy/functional/frontends/numpy/ndarray/ndarray.py
@@ -637,6 +637,9 @@ def var(
where=where,
)

def __irshift__(self, value, /):
return ivy.bitwise_right_shift(self.ivy_array, value, out=self)


# --- Helpers --- #
# --------------- #
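
Adding __irshift__ lets the augmented assignment arr >>= n dispatch to ivy.bitwise_right_shift on the frontend ndarray. The NumPy semantics it is intended to mirror:

import numpy as np

a = np.array([8, 16, 32])
a >>= 2            # in-place bitwise right shift, i.e. what __irshift__ implements
print(a)           # [2 4 8]
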
9 changes: 9 additions & 0 deletions ivy/functional/frontends/paddle/manipulation.py
@@ -197,6 +197,15 @@ def unbind(input, axis=0):
return tuple([x.reshape(tuple(shape)) for x in split(input, num_splits, axis=axis)])


@with_supported_dtypes(
{"2.5.1 and below": ("bool", "int32", "int64", "float16", "float32", "float64")},
"paddle",
)
@to_ivy_arrays_and_back
def unique_consecutive(x, axis=0):
return ivy.unique_consecutive(x, axis=axis)


@with_supported_dtypes(
{
"2.5.1 and below": (
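
unique_consecutive collapses only runs of equal neighbouring values (unlike unique, which deduplicates globally). For reference, the Paddle behaviour the frontend forwards to ivy.unique_consecutive to reproduce, assuming Paddle is installed:

import paddle

x = paddle.to_tensor([1, 1, 2, 2, 2, 1, 3, 3])
# only consecutive repeats collapse; the later 1 survives because it is not adjacent to the first run
print(paddle.unique_consecutive(x))   # [1, 2, 1, 3]
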
10 changes: 10 additions & 0 deletions ivy/functional/frontends/paddle/random.py
@@ -7,6 +7,16 @@
)


@with_supported_dtypes(
{"2.5.1 and below": ("float32", "float64")},
"paddle",
)
@to_ivy_arrays_and_back
def multinomial(x, num_samples=1, replacement=False, name=None):
n = num_samples + 1
return ivy.multinomial(n, num_samples, probs=x, replace=replacement)


@with_supported_dtypes(
{"2.5.1 and below": ("float32", "float64")},
"paddle",
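
multinomial draws category indices from the (unnormalised) probabilities in x. A reference sketch of the Paddle call the frontend mirrors, assuming Paddle is installed; the exact samples are random:

import paddle

probs = paddle.to_tensor([0.1, 0.2, 0.3, 0.4])
samples = paddle.multinomial(probs, num_samples=3, replacement=True)
print(samples.shape)   # [3] – three indices in the range 0..3
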
23 changes: 23 additions & 0 deletions ivy/functional/frontends/paddle/tensor/tensor.py
@@ -865,10 +865,33 @@ def fill_(self, value):
def unbind(self, axis=0):
return paddle_frontend.unbind(self._ivy_array, axis=axis)

@with_supported_dtypes(
{
"2.5.1 and below": (
"bool",
"int32",
"int64",
"float16",
"float32",
"float64",
)
},
"paddle",
)
def unique_consecutive(self, axis=0):
return paddle_frontend.unique_consecutive(self._ivy_array, axis=axis)

def cpu(self):
self.ivy_array = ivy.to_device(self.ivy_array, ivy.as_ivy_dev("cpu"))
return self

@with_unsupported_dtypes(
{"2.5.1 and below": ("int16", "complex64", "complex128")},
"paddle",
)
def split(self, num_or_sections, axis=0, name=None):
return paddle_frontend.split(self._ivy_array, num_or_sections, axis, name)

@with_supported_dtypes(
{"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle"
)
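
These Tensor methods delegate to the frontend's module-level unique_consecutive and split, so tensor.split(...) behaves like the functional form. For reference, the native Paddle splitting behaviour this mirrors, assuming Paddle is installed:

import paddle

x = paddle.arange(12).reshape([3, 4])
a, b = paddle.split(x, num_or_sections=2, axis=1)   # two chunks of shape [3, 2]
print(a.shape, b.shape)                             # [3, 2] [3, 2]
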
2 changes: 1 addition & 1 deletion ivy/functional/frontends/tensorflow/__init__.py
@@ -88,7 +88,7 @@ def check_tensorflow_casting(x1, x2):
from . import ragged
from .ragged import *
from . import tensor
from .tensor import EagerTensor, Tensor
from .tensor import EagerTensor, Tensor, TensorArray
from . import variable
from .variable import Variable, IndexedSlices
from . import keras
213 changes: 213 additions & 0 deletions ivy/functional/frontends/tensorflow/tensor.py
@@ -1,4 +1,5 @@
# global
import weakref

# local
import ivy
@@ -228,6 +229,218 @@ def __iter__(self):
yield self[i]


class TensorArray:
def __init__(
self,
dtype,
size=None,
dynamic_size=None,
clear_after_read=None,
tensor_array_name=None,
handle=None,
flow=None,
infer_shape=True,
element_shape=None,
colocate_with_first_write_call=True,
name=None,
):
del (flow, tensor_array_name, name)
self._handle = None
self._flow = tf_frontend.constant(0, dtype=tf_frontend.int32)
self._infer_shape = infer_shape
self._element_shape = (
ivy.Shape(element_shape) if element_shape is not None else element_shape
)
self._colocate_with_first_write_call = colocate_with_first_write_call
self._dtype = tf_frontend.as_dtype(dtype)
self._dynamic_size = dynamic_size or False
self._clear_after_read = True if clear_after_read is None else clear_after_read
self._previously_read_indices = []

if isinstance(size, EagerTensor):
size = size.ivy_array
self._tensor_array = [None for _ in range(size)]
self._parent = weakref.ref(self)

@property
def flow(self):
return self._flow

@property
def dtype(self):
return self._dtype

@property
def handle(self):
return self._handle

@property
def element_shape(self):
return self._element_shape

def identity(self):
return self._parent()

def grad(self, source, flow=None, name=None):
raise NotImplementedError(
"TensorArray.grad is not supported when executing eagerly; eager's "
"gradient implementation does not use/need this function to compute "
"gradients of operations that use TensorArrays."
)

@property
def dynamic_size(self):
return self._dynamic_size

@property
def infer_shape(self):
return self._infer_shape

def read(self, index, name=None):
if isinstance(index, EagerTensor):
index = ivy.to_scalar(index.ivy_array)

if index < 0:
raise IndexError(f"Reading from negative indices {index} is not allowed.")

if index >= len(self._tensor_array):
raise IndexError(
f"Tried to read from index {index} but array size is:"
f" {len(self._tensor_array)} "
)

tensor = self._tensor_array[index]
if tensor is None:
if index in self._previously_read_indices:
raise ValueError(
f"Could not read index {index} twice because it was cleared after a"
" previous read (perhaps try setting clear_after_read = false?)"
)
else:
tensor = self._tensor_array[index] = tf_frontend.zeros(
shape=self._element_shape, dtype=self._dtype
)

if self._clear_after_read:
self._tensor_array[index] = None
self._previously_read_indices.append(index)
return tensor

def _write(self, index, value, name=None):
if isinstance(index, EagerTensor):
index = ivy.to_scalar(index.ivy_array)

if index < 0:
raise IndexError(f"Reading from negative indices {index} is not allowed.")

size = len(self._tensor_array)
if index >= size:
if not self._dynamic_size:
raise IndexError(
"Tried to write to index {index} but array is not resizeable and"
" size is: {size}"
)
self._tensor_array.extend(None for _ in range(index - size + 1))

if not isinstance(value, EagerTensor):
value = tf_frontend.cast(value, self.dtype)

if self._dtype != value.dtype:
raise ValueError(
f"TensorArray dtype is {self._dtype} but Op is trying to write dtype"
f" {value.dtype} "
)

if self._infer_shape:
self._element_shape = self._merge_shape(value)

self._tensor_array[index] = value

def _merge_shape(self, value):
if self._element_shape is None:
return value.shape
if len(self._element_shape) != len(value.shape):
raise ValueError("Shapes not compatible")
shape = []
for a, b in zip(self._element_shape, value.shape):
if a == b or a is None:
shape.append(b)
else:
raise ValueError("Shapes not compatible")
return tuple(shape)

def write(self, index, value, name=None):
self._write(index, value)
return self._parent()

def stack(self, name=None):
if self._tensor_array:
for ix in range(len(self._tensor_array)):
if self._tensor_array[ix] is None:
self._tensor_array[ix] = tf_frontend.zeros(
shape=self._element_shape, dtype=self._dtype
)
if not self._tensor_array and self._element_shape.is_fully_defined():
return tf_frontend.constant(
[0] + list(self.element_shape), dtype=self._dtype
)
else:
return tf_frontend.stack(self._tensor_array)

def _maybe_zero(self, ix):
val = self._tensor_array[ix]
if val is None:
val = self._tensor_array[ix] = tf_frontend.zeros(
shape=self._element_shape, dtype=self._dtype
)
return val

def gather(self, indices, name=None):
if isinstance(indices, EagerTensor):
indices = indices.ivy_array
return tf_frontend.stack([self._maybe_zero(i) for i in indices])

def concat(self, name=None):
return tf_frontend.concat(
[self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
0,
name=name,
)

def unstack(self, value, name=None):
tensors = tf_frontend.unstack(value, name=name)
if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
raise ValueError(
f"Cannot unstack {len(tensors)} tensors into a TensorArray of static"
f" size {len(self._tensor_array)} "
)
self._tensor_array = tensors
return self._parent()

def scatter(self, indices, value, name=None):
if isinstance(indices, EagerTensor):
indices = indices.ivy_array
for index, val in zip(indices, tf_frontend.unstack(value)):
self._write(index, val)
return self._parent()

def size(self, name=None):
return tf_frontend.constant(len(self._tensor_array))

def close(self, name=None):
del self._tensor_array[:]

def split(self, value, lengths, name=None):
value = tf_frontend.cast(value, self.dtype)
lengths = (
tf_frontend.constant(lengths)
if not isinstance(lengths, EagerTensor)
else lengths
)
self._tensor_array = tf_frontend.split(value, lengths, name=name)
return self._parent()


# Dummy Tensor class to help with compilation, don't add methods here
class Tensor(EagerTensor):
pass
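
The class above reimplements eager-mode TensorArray behaviour for the TensorFlow frontend: a growable list of tensors with write/read/stack/gather/concat/scatter/split operations and optional clear-after-read semantics. A short reference of the native tf.TensorArray usage it is modelled on, assuming TensorFlow is installed:

import tensorflow as tf

ta = tf.TensorArray(dtype=tf.float32, size=2, clear_after_read=False)
ta = ta.write(0, [1.0, 2.0])   # write returns the updated array, so reassign
ta = ta.write(1, [3.0, 4.0])
print(ta.read(0).numpy())      # [1. 2.]
print(ta.stack().numpy())      # [[1. 2.] [3. 4.]]
print(ta.size().numpy())       # 2
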