From 059bc5ef8bed8f80c0b33d2b99774d94fa552b68 Mon Sep 17 00:00:00 2001 From: TF2JAXDev Date: Mon, 27 May 2024 08:15:25 -0700 Subject: [PATCH] Implement "bad_indcies_policy" for ScatterNd. For testing, we also introduced "ScatterNdTest" that verifies the existing "default" behavior for comparison. PiperOrigin-RevId: 637646019 --- tf2jax/_src/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index a41839a..d511f7e 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -1727,7 +1727,7 @@ def _func( @register_operation("ScatterNd") def _scatter_nd(proto): """Parse a ScatterNd op.""" - _check_attrs(proto, {"T", "Tindices"}) + _check_attrs(proto, {"T", "Tindices", "bad_indices_policy"}) def _func( indices: jnp.ndarray,