diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index a41839a..d511f7e 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -1727,7 +1727,7 @@ def _func( @register_operation("ScatterNd") def _scatter_nd(proto): """Parse a ScatterNd op.""" - _check_attrs(proto, {"T", "Tindices"}) + _check_attrs(proto, {"T", "Tindices", "bad_indices_policy"}) def _func( indices: jnp.ndarray,