diff --git a/tf2jax/_src/xla_utils.py b/tf2jax/_src/xla_utils.py index e97f53e..9793239 100644 --- a/tf2jax/_src/xla_utils.py +++ b/tf2jax/_src/xla_utils.py @@ -68,16 +68,24 @@ def gather_dimension_numbers_from_proto( message) -> jax.lax.GatherDimensionNumbers: proto = xla_data_pb2.GatherDimensionNumbers().FromString(message) return jax.lax.GatherDimensionNumbers( - tuple(proto.offset_dims), tuple(proto.collapsed_slice_dims), - tuple(proto.start_index_map)) + tuple(proto.offset_dims), + tuple(proto.collapsed_slice_dims), + tuple(proto.start_index_map), + tuple(proto.operand_batching_dims), + tuple(proto.start_indices_batching_dims), + ) def scatter_dimension_numbers_from_proto( message) -> jax.lax.ScatterDimensionNumbers: proto = xla_data_pb2.ScatterDimensionNumbers().FromString(message) return jax.lax.ScatterDimensionNumbers( - tuple(proto.update_window_dims), tuple(proto.inserted_window_dims), - tuple(proto.scatter_dims_to_operand_dims)) + tuple(proto.update_window_dims), + tuple(proto.inserted_window_dims), + tuple(proto.scatter_dims_to_operand_dims), + tuple(proto.input_batching_dims), + tuple(proto.scatter_indices_batching_dims), + ) def precision_config_from_proto(