Skip to content

Commit

Permalink
Fix static shape for take_along_axis with Tensorflow. (keras-team#1…
Browse files Browse the repository at this point in the history
…9656)

Note that this fix will replace a different fix in Keras-nlp that addresses the same issue https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/src/backend/ops.py#L26
  • Loading branch information
hertschuh authored May 3, 2024
1 parent 1941c30 commit e4f5092
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 27 deletions.
17 changes: 12 additions & 5 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,10 @@ def take(x, indices, axis=None):


def take_along_axis(x, indices, axis=None):
from keras.src.ops.operation_utils import (
compute_take_along_axis_output_shape,
)

x = convert_to_tensor(x)
indices = convert_to_tensor(indices, "int64")
if axis is None:
Expand All @@ -1959,7 +1963,13 @@ def take_along_axis(x, indices, axis=None):
f"Received: indices.shape={indices.shape}"
)
return take_along_axis(tf.reshape(x, [-1]), indices, 0)
rank = tf.rank(x)

# Compute the static output shape as later on, all shapes manipulations
# use dynamic shapes.
static_output_shape = compute_take_along_axis_output_shape(
x.shape, indices.shape, axis
)
rank = x.ndim
static_axis = axis
axis = axis + rank if axis < 0 else axis

Expand All @@ -1981,9 +1991,6 @@ def take_along_axis(x, indices, axis=None):
x = tf.broadcast_to(x, x_shape)
indices = tf.broadcast_to(indices, indices_shape)

# Save indices shape so we can restore it later.
possible_result_shape = indices.shape

# Correct the indices using "fill" mode which is the same as in jax
indices = tf.where(indices < 0, indices + x_shape[static_axis], indices)

Expand All @@ -1998,7 +2005,7 @@ def take_along_axis(x, indices, axis=None):
result = tf.gather(x, indices, batch_dims=1)
result = tf.reshape(result, indices_shape)
result = swapaxes(result, static_axis, -1)
result.set_shape(possible_result_shape)
result.set_shape(static_output_shape)
return result


Expand Down
25 changes: 3 additions & 22 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4859,28 +4859,9 @@ def call(self, x, indices):
return backend.numpy.take_along_axis(x, indices, axis=self.axis)

def compute_output_spec(self, x, indices):
x_shape = list(x.shape)
indices_shape = list(indices.shape)
if self.axis is None:
x_shape = [None] if None in x_shape else [int(np.prod(x_shape))]

if len(x_shape) != len(indices_shape):
raise ValueError(
"`x` and `indices` must have the same number of dimensions, "
f"but receive shape {x_shape} and {indices_shape}."
)

del x_shape[self.axis]
del indices_shape[self.axis]
output_shape = broadcast_shapes(x_shape, indices_shape)
size_on_axis = indices.shape[self.axis]
if self.axis == -1:
output_shape = output_shape + [size_on_axis]
elif self.axis >= 0:
output_shape.insert(self.axis, size_on_axis)
else:
output_shape.insert(self.axis + 1, size_on_axis)

output_shape = operation_utils.compute_take_along_axis_output_shape(
x.shape, indices.shape, self.axis
)
return KerasTensor(output_shape, dtype=x.dtype)


Expand Down
19 changes: 19 additions & 0 deletions keras/src/ops/operation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,25 @@ def compute_transpose_output_shape(input_shape, axes):
return tuple(input_shape[ax] for ax in axes)


def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):
input_shape = list(input_shape)
indices_shape = list(indices_shape)
if axis is None:
input_shape = (
[None] if None in input_shape else [int(np.prod(input_shape))]
)

if len(input_shape) != len(indices_shape):
raise ValueError(
"`x` and `indices` must have the same number of dimensions, "
f"but receive shape {input_shape} and {indices_shape}."
)

input_shape[axis] = indices_shape[axis]
output_shape = broadcast_shapes(input_shape, indices_shape)
return output_shape


def reduce_shape(shape, axis=None, keepdims=False):
shape = list(shape)
if axis is None:
Expand Down

0 comments on commit e4f5092

Please sign in to comment.