Skip to content

Commit

Permalink
Implement "bad_indcies_policy" for ScatterNd.
Browse files Browse the repository at this point in the history
For testing, we also introduced "ScatterNdTest" that verifies the existing "default" behavior for comparison.

PiperOrigin-RevId: 637646019
  • Loading branch information
TF2JAXDev authored and TF2JAXDev committed May 29, 2024
1 parent ad9afbb commit 059bc5e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ def _func(
@register_operation("ScatterNd")
def _scatter_nd(proto):
"""Parse a ScatterNd op."""
_check_attrs(proto, {"T", "Tindices"})
_check_attrs(proto, {"T", "Tindices", "bad_indices_policy"})

def _func(
indices: jnp.ndarray,
Expand Down

0 comments on commit 059bc5e

Please sign in to comment.