Skip to content

Commit

Permalink
fix: for tf.function compatibility in tf backend (#28665)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbarrett98 authored Apr 4, 2024
1 parent ccf1894 commit cd1ac9b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
6 changes: 3 additions & 3 deletions ivy/functional/backends/tensorflow/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def arange(

# convert builtin types to tf scalars, as is expected by tf.range
if isinstance(start, (float, int)):
start = tf.constant(start)
start = tf.convert_to_tensor(start)
if isinstance(stop, (float, int)):
stop = tf.constant(stop)
stop = tf.convert_to_tensor(stop)
if isinstance(step, (float, int)):
step = tf.constant(step)
step = tf.convert_to_tensor(step)

if dtype is None:
if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def unique_consecutive(
x_shape = None
if axis is None:
x_shape = x.shape
x = tf.reshape(x, -1)
x = tf.reshape(x, tf.constant([-1]))
axis = -1
ndim = len(x.shape)
if axis < 0:
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def dtype(self):

@property
def shape(self):
return Size(self.ivy_array.shape)
return Size(ivy.shape(self.ivy_array, as_array=True))

@property
def real(self):
Expand Down
5 changes: 4 additions & 1 deletion ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2978,7 +2978,10 @@ def _parse_query(query, x_shape, scatter=False):
[list(query[i].shape) for i in range(0, array_inds[0])]
+ [list(ivy.shape(array_queries[0], as_array=True))]
+ [[] for _ in range(len(array_inds) - 1)]
+ [list(query[i].shape) for i in range(array_inds[-1] + 1, len(query))]
+ [
list(ivy.shape(query[i], as_array=True))
for i in range(array_inds[-1] + 1, len(query))
]
)
else:
target_shape = [list(q.shape) for q in query]
Expand Down

0 comments on commit cd1ac9b

Please sign in to comment.