Skip to content

Commit

Permalink
ScatterNdFunctor to update operands of all valid indices and contin…
Browse files Browse the repository at this point in the history
…ue on bad indices.

This is to support the new attribute "bad_indices_policy". Passing downs the behavior also works, but it makes `ScatterNdFunctor` unnecessarily complicated while the only gain is the performance with out-of-bound error.

PiperOrigin-RevId: 637646021
  • Loading branch information
TF2JAXDev authored and TF2JAXDev committed May 29, 2024
1 parent ad9afbb commit 451e715
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ def _func(
@register_operation("ScatterNd")
def _scatter_nd(proto):
"""Parse a ScatterNd op."""
_check_attrs(proto, {"T", "Tindices"})
_check_attrs(proto, {"T", "Tindices", "bad_indices_policy"})

def _func(
indices: jnp.ndarray,
Expand Down

0 comments on commit 451e715

Please sign in to comment.