From e2b43e2e74dfea55105763484b47a082da13b0e7 Mon Sep 17 00:00:00 2001 From: james77777778 <20734616+james77777778@users.noreply.github.com> Date: Fri, 15 Mar 2024 23:53:05 +0800 Subject: [PATCH] Replace `dm-tree` with `optree` (#19306) * Refactor `keras.utils.tree` * Fix tests * Replace `dm-tree` with `optree` * Eliminate `tf.nest` * Resolve comments * Fix merge conflicts * Update exporting path --- keras/backend/tensorflow/core.py | 7 +- keras/backend/tensorflow/layer.py | 13 +- keras/backend/tensorflow/numpy.py | 23 +- keras/backend/tensorflow/trainer.py | 14 +- keras/export/export_lib.py | 21 +- keras/models/cloning.py | 2 +- keras/ops/core_test.py | 5 +- keras/utils/tracking.py | 36 ++ keras/utils/tree.py | 567 ++++++++++++++++++++++++---- keras/utils/tree_test.py | 291 ++++++++++++++ requirements-common.txt | 2 +- setup.py | 2 +- 12 files changed, 868 insertions(+), 115 deletions(-) create mode 100644 keras/utils/tree_test.py diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index 0e781aa0c09..a68f2ce203d 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -9,6 +9,7 @@ from keras.backend.common.name_scope import name_scope as base_name_scope from keras.backend.common.stateless_scope import StatelessScope from keras.backend.common.stateless_scope import in_stateless_scope +from keras.utils import tree from keras.utils.naming import auto_name SUPPORTS_SPARSE_TENSORS = True @@ -189,7 +190,7 @@ def convert_keras_tensor_to_tf(x): ) return x - args, kwargs = tf.nest.map_structure( + args, kwargs = tree.map_structure( convert_keras_tensor_to_tf, (args, kwargs) ) tf_out = fn(*args, **kwargs) @@ -201,9 +202,7 @@ def convert_tf_to_keras_tensor(x): ) return x - output_spec = tf.nest.map_structure( - convert_tf_to_keras_tensor, tf_out - ) + output_spec = tree.map_structure(convert_tf_to_keras_tensor, tf_out) return output_spec diff --git a/keras/backend/tensorflow/layer.py b/keras/backend/tensorflow/layer.py index 2a1b6663468..9b586bc9ec9 100644 --- a/keras/backend/tensorflow/layer.py +++ b/keras/backend/tensorflow/layer.py @@ -3,6 +3,7 @@ from keras.backend.tensorflow.trackable import KerasAutoTrackable from keras.utils import tf_utils from keras.utils import tracking +from keras.utils import tree class TFLayer(KerasAutoTrackable): @@ -27,16 +28,16 @@ def _set_save_spec(self, inputs, args=None, kwargs=None): if self._saved_model_inputs_spec is not None: return # Already set. - inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs) - args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or []) + inputs_spec = tree.map_structure(tf_utils.get_tensor_spec, inputs) + args_spec = tree.map_structure(tf_utils.get_tensor_spec, args or []) kwargs_spec = {} # Filter out non-tensor arguments from kwargs. for key, kwarg in kwargs.items(): - flat_kwarg = tf.nest.flatten(kwarg) + flat_kwarg = tree.flatten(kwarg) flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg] if any(s is None for s in flat_specs): continue - kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs) + kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs) self._saved_model_inputs_spec = inputs_spec self._saved_model_arg_spec = ( @@ -94,7 +95,7 @@ def _default_save_signature(self): if inputs is not None: input_signature = [ - tf.nest.map_structure( + tree.map_structure( lambda x: tf.TensorSpec(x.shape, self.compute_dtype), inputs, ) @@ -108,7 +109,7 @@ def _default_save_signature(self): ] else: input_signature = [ - tf.nest.map_structure( + tree.map_structure( lambda x: tf.TensorSpec(x.shape, self.compute_dtype), shapes_dict, ) diff --git a/keras/backend/tensorflow/numpy.py b/keras/backend/tensorflow/numpy.py index 1c652a5c17e..24b5ae13dfb 100644 --- a/keras/backend/tensorflow/numpy.py +++ b/keras/backend/tensorflow/numpy.py @@ -17,6 +17,7 @@ from keras.backend.common.backend_utils import to_tuple_or_list from keras.backend.tensorflow import sparse from keras.backend.tensorflow.core import convert_to_tensor +from keras.utils import tree @sparse.elementwise_binary_union(tf.sparse.add) @@ -95,7 +96,7 @@ def _normalize_einsum_subscripts(subscripts): def einsum(subscripts, *operands, **kwargs): - operands = tf.nest.map_structure(convert_to_tensor, operands) + operands = tree.map_structure(convert_to_tensor, operands) subscripts = _normalize_einsum_subscripts(subscripts) def is_valid_for_custom_ops(subscripts, *operands): @@ -240,7 +241,7 @@ def use_custom_ops(subscripts, *operands, output_type): # output_type="int32" if "int" in compute_dtype and output_type is None: compute_dtype = config.floatx() - operands = tf.nest.map_structure( + operands = tree.map_structure( lambda x: tf.cast(x, compute_dtype), operands ) result = use_custom_ops(subscripts, *operands, output_type=output_type) @@ -248,7 +249,7 @@ def use_custom_ops(subscripts, *operands, output_type): # TODO: tf.einsum doesn't support integer dtype with gpu if "int" in compute_dtype: compute_dtype = config.floatx() - operands = tf.nest.map_structure( + operands = tree.map_structure( lambda x: tf.cast(x, compute_dtype), operands ) result = tf.einsum(subscripts, *operands, **kwargs) @@ -763,11 +764,11 @@ def concatenate(xs, axis=0): ) for x in xs ] - xs = tf.nest.map_structure(convert_to_tensor, xs) + xs = tree.map_structure(convert_to_tensor, xs) dtype_set = set([x.dtype for x in xs]) if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) - xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs) + xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs) return tf.concat(xs, axis=axis) @@ -872,7 +873,7 @@ def digitize(x, bins): bins = list(bins) # bins must be float type - bins = tf.nest.map_structure(lambda x: float(x), bins) + bins = tree.map_structure(lambda x: float(x), bins) # TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8 # int16, uint8, uint16, uint32 @@ -1023,7 +1024,7 @@ def hstack(xs): dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) - xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs) + xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs) rank = tf.rank(xs[0]) return tf.cond( tf.equal(rank, 1), @@ -1328,9 +1329,7 @@ def ndim(x): def nonzero(x): x = convert_to_tensor(x) result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1) - return tf.nest.map_structure( - lambda indices: tf.cast(indices, "int32"), result - ) + return tree.map_structure(lambda indices: tf.cast(indices, "int32"), result) def not_equal(x1, x2): @@ -1620,7 +1619,7 @@ def stack(x, axis=0): dtype_set = set([getattr(a, "dtype", type(a)) for a in x]) if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) - x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x) + x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x) return tf.stack(x, axis=axis) @@ -1807,7 +1806,7 @@ def vstack(xs): dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) - xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs) + xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs) return tf.concat(xs, axis=0) diff --git a/keras/backend/tensorflow/trainer.py b/keras/backend/tensorflow/trainer.py index f8afcd9ed65..204094ee690 100644 --- a/keras/backend/tensorflow/trainer.py +++ b/keras/backend/tensorflow/trainer.py @@ -225,7 +225,7 @@ def multi_step_on_data(data): outputs = one_step_on_data_distributed(data[:1]) for single_step_data in data[1:]: step_outputs = one_step_on_data_distributed([single_step_data]) - outputs = tf.nest.map_structure( + outputs = tree.map_structure( lambda t1, t2: concat([t1, t2]), outputs, step_outputs ) return outputs @@ -473,7 +473,7 @@ def predict( def append_to_outputs(batch_outputs, outputs): if outputs is None: - outputs = tf.nest.map_structure( + outputs = tree.map_structure( lambda batch_output: [batch_output], batch_outputs, ) @@ -521,7 +521,7 @@ def get_data(iterator): outputs = tree.map_structure_up_to( batch_outputs, potentially_ragged_concat, outputs ) - return tf.nest.map_structure(convert_to_np_if_not_ragged, outputs) + return tree.map_structure(convert_to_np_if_not_ragged, outputs) def train_on_batch( self, @@ -549,7 +549,7 @@ def data(): yield (x, y, sample_weight) logs = self.train_function(data()) - logs = tf.nest.map_structure(lambda x: np.array(x), logs) + logs = tree.map_structure(lambda x: np.array(x), logs) if return_dict: return logs return self._flatten_metrics_in_order(logs) @@ -568,7 +568,7 @@ def data(): yield (x, y, sample_weight) logs = self.test_function(data()) - logs = tf.nest.map_structure(lambda x: np.array(x), logs) + logs = tree.map_structure(lambda x: np.array(x), logs) if return_dict: return logs return self._flatten_metrics_in_order(logs) @@ -576,7 +576,7 @@ def data(): def predict_on_batch(self, x): self.make_predict_function() batch_outputs = self.predict_function([(x,)]) - batch_outputs = tf.nest.map_structure( + batch_outputs = tree.map_structure( convert_to_np_if_not_ragged, batch_outputs ) return batch_outputs @@ -771,7 +771,7 @@ def _reduce(v): f"Received: reduction={reduction}." ) - return tf.nest.map_structure(_reduce, values) + return tree.map_structure(_reduce, values) def _multi_worker_concat(v, strategy): diff --git a/keras/export/export_lib.py b/keras/export/export_lib.py index 3633e58767e..7e607d29215 100644 --- a/keras/export/export_lib.py +++ b/keras/export/export_lib.py @@ -8,6 +8,7 @@ from keras.models import Functional from keras.models import Sequential from keras.utils import io_utils +from keras.utils import tree from keras.utils.module_utils import tensorflow as tf @@ -143,16 +144,16 @@ def track(self, resource): # Variables in the lists below are actually part of the trackables # that get saved, because the lists are created in __init__. if backend.backend() == "jax": - self._tf_trackable.variables += tf.nest.flatten( - tf.nest.map_structure(tf.Variable, resource.variables) + self._tf_trackable.variables += tree.flatten( + tree.map_structure(tf.Variable, resource.variables) ) - self._tf_trackable.trainable_variables += tf.nest.flatten( - tf.nest.map_structure( + self._tf_trackable.trainable_variables += tree.flatten( + tree.map_structure( tf.Variable, resource.trainable_variables ) ) - self._tf_trackable.non_trainable_variables += tf.nest.flatten( - tf.nest.map_structure( + self._tf_trackable.non_trainable_variables += tree.flatten( + tree.map_structure( tf.Variable, resource.non_trainable_variables ) ) @@ -362,9 +363,7 @@ def add_variable_collection(self, name, variables): f"{list(set(type(v) for v in variables))}" ) if backend.backend() == "jax": - variables = tf.nest.flatten( - tf.nest.map_structure(tf.Variable, variables) - ) + variables = tree.flatten(tree.map_structure(tf.Variable, variables)) setattr(self._tf_trackable, name, list(variables)) def write_out(self, filepath, options=None): @@ -470,7 +469,7 @@ def _convert_jax2tf_function(self, fn, input_signature): def _spec_to_poly_shape(self, spec): if isinstance(spec, (dict, list)): - return tf.nest.map_structure(self._spec_to_poly_shape, spec) + return tree.map_structure(self._spec_to_poly_shape, spec) spec_shape = spec.shape spec_shape = str(spec_shape).replace("None", "b") return spec_shape @@ -500,7 +499,7 @@ def export_model(model, filepath): export_archive = ExportArchive() export_archive.track(model) if isinstance(model, (Functional, Sequential)): - input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs) + input_signature = tree.map_structure(_make_tensor_spec, model.inputs) if isinstance(input_signature, list) and len(input_signature) > 1: input_signature = [input_signature] export_archive.add_endpoint("serve", model.__call__, input_signature) diff --git a/keras/models/cloning.py b/keras/models/cloning.py index 95c209b1875..cb5ec923aec 100644 --- a/keras/models/cloning.py +++ b/keras/models/cloning.py @@ -261,7 +261,7 @@ def _clone_layer(layer): ) try: tree.assert_same_structure(input_tensors, model.input) - except TypeError as e: + except (ValueError, TypeError) as e: raise ValueError( "`input_tensors` must have the same structure as model.input" f"\nReference structure: {model.input}" diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index e3dd47ae68d..774b19760d4 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -789,9 +789,8 @@ def test_cond_check_output_spec_list_tuple(self): def test_cond_check_output_spec_other_types(self): cond_op = core.Cond() - # Create mock objects with dtype and shape attributes - mock_spec1 = Mock(dtype="float32", shape=(2, 2)) - mock_spec2 = Mock(dtype="float32", shape=(2, 2)) + mock_spec1 = KerasTensor(shape=(2, 2), dtype="float32") + mock_spec2 = KerasTensor(shape=(2, 2), dtype="float32") self.assertTrue(cond_op._check_output_spec(mock_spec1, mock_spec2)) def test_cond_check_output_spec_none(self): diff --git a/keras/utils/tracking.py b/keras/utils/tracking.py index 27537c27367..f883b3a5541 100644 --- a/keras/utils/tracking.py +++ b/keras/utils/tracking.py @@ -1,5 +1,8 @@ from functools import wraps +import optree +import optree.utils + from keras.backend.common.global_state import get_global_attribute from keras.backend.common.global_state import set_global_attribute from keras.utils import python_utils @@ -110,6 +113,7 @@ def add_to_store(self, store_name, value): self.stored_ids[store_name].add(id(value)) +@optree.register_pytree_node_class(namespace="keras") class TrackedList(list): def __init__(self, values=None, tracker=None): self.tracker = tracker @@ -160,7 +164,17 @@ def __delitem__(self, index): if self.tracker: self.tracker.untrack(value) + def tree_flatten(self): + # For optree + return (self, None) + + @classmethod + def tree_unflatten(cls, metadata, children): + # For optree + return cls(children) + +@optree.register_pytree_node_class(namespace="keras") class TrackedDict(dict): def __init__(self, values=None, tracker=None): self.tracker = tracker @@ -199,7 +213,20 @@ def clear(self): self.tracker.untrack(value) super().clear() + def tree_flatten(self): + # For optree + keys, values = optree.utils.unzip2( + optree.utils.total_order_sorted(self.items(), key=lambda kv: kv[0]) + ) + return values, list(keys), keys + + @classmethod + def tree_unflatten(cls, keys, values): + # For optree + return cls(optree.utils.safe_zip(keys, values)) + +@optree.register_pytree_node_class(namespace="keras") class TrackedSet(set): def __init__(self, values=None, tracker=None): self.tracker = tracker @@ -233,3 +260,12 @@ def clear(self): for value in self: self.tracker.untrack(value) super().clear() + + def tree_flatten(self): + # For optree + return (self, None) + + @classmethod + def tree_unflatten(cls, metadata, children): + # For optree + return cls(children) diff --git a/keras/utils/tree.py b/keras/utils/tree.py index 0ac47db9adc..2788ceb42bb 100644 --- a/keras/utils/tree.py +++ b/keras/utils/tree.py @@ -1,56 +1,389 @@ -import tree +import collections +import collections.abc +import types +import optree +from keras.api_export import keras_export +from keras.backend.config import backend + +# Register backend-specific node classes +if backend() == "tensorflow": + from tensorflow.python.trackable.data_structures import ListWrapper + + optree.register_pytree_node( + ListWrapper, + lambda x: (x, None), + lambda metadata, children: ListWrapper(list(children)), + namespace="keras", + ) + + +@keras_export("keras.tree.is_nested") def is_nested(structure): - return tree.is_nested(structure) + """Checks if a given structure is nested. + + Examples: + + >>> keras.tree.is_nested(42) + False + >>> keras.tree.is_nested({"foo": 42}) + True + + Args: + structure: A structure to check. + + Returns: + `True` if a given structure is nested, i.e. is a sequence, a mapping, + or a namedtuple, and `False` otherwise. + """ + return not optree.tree_is_leaf( + structure, none_is_leaf=True, namespace="keras" + ) + + +@keras_export("keras.tree.traverse") +def traverse(func, structure, top_down=True): + """Traverses the given nested structure, applying the given function. + + The traversal is depth-first. If `top_down` is True (default), parents + are returned before their children (giving the option to avoid traversing + into a sub-tree). + + Examples: + >>> v = [] + >>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=True) + [(1, 2), [3], {'a': 4}] + >>> v + [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4] + >>> v = [] + >>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=False) + [(1, 2), [3], {'a': 4}] + >>> v + [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]] + + Args: + func: The function to be applied to each sub-nest of the structure. + + When traversing top-down: + If `func(subtree) is None` the traversal continues into the + sub-tree. + If `func(subtree) is not None` the traversal does not continue + into the sub-tree. The sub-tree will be replaced by `func(subtree)` + in the returned structure (to replace the sub-tree with `None`, use + the special value `_MAP_TO_NONE`). + + When traversing bottom-up: + If `func(subtree) is None` the traversed sub-tree is returned + unaltered. + If `func(subtree) is not None` the sub-tree will be replaced by + `func(subtree)` in the returned structure (to replace the sub-tree + with None, use the special value `_MAP_TO_NONE`). + + structure: The structure to traverse. + top_down: If True, parent structures will be visited before their + children. + + Returns: + The structured output from the traversal. + """ + + # From https://github.com/google/jax/pull/19695 + def traverse_children(): + children, treedef = optree.tree_flatten( + structure, + is_leaf=lambda x: x is not structure, + none_is_leaf=True, + namespace="keras", + ) + if treedef.num_nodes == 1 and treedef.num_leaves == 1: + return structure + else: + return optree.tree_unflatten( + treedef, + [traverse(func, c, top_down=top_down) for c in children], + ) + + if top_down: + ret = func(structure) + if ret is None: + return traverse_children() + else: + traversed_structure = traverse_children() + ret = func(traversed_structure) + if ret is None: + return traversed_structure + return None if ret is _MAP_TO_NONE else ret + + +@keras_export("keras.tree.flatten") def flatten(structure): - return tree.flatten(structure) + """Flattens a possibly nested structure into a list. + + In the case of dict instances, the sequence consists of the values, + sorted by key to ensure deterministic behavior. This is true also for + `collections.OrderedDict` instances: their sequence order is + considered. The same convention is followed in `unflatten_as`. + This correctly unflattens dicts and `OrderedDict` after they have been + flattened, or vice-versa. + + Dictionaries with non-sortable keys cannot be flattened. + + Examples: + >>> keras.tree.flatten([[1, 2, 3], [4, [5], [[6]]]]) + [1, 2, 3, 4, 5, 6] + >>> keras.tree.flatten(None) + [None] + >>> keras.tree.flatten(1) + [1] + >>> keras.tree.flatten({100: 'world!', 6: 'Hello'}) + ['Hello', 'world!'] -def map_structure(func, *structures, **kwargs): - return tree.map_structure(func, *structures, **kwargs) + Args: + structure: An arbitrarily nested structure. + + Returns: + A list, the flattened version of the input `structure`. + """ + # optree.tree_flatten returns a pair (leaves, treespec) where the first + # element is a list of leaf values and the second element is a treespec + # representing the structure of the pytree. + leaves, _ = optree.tree_flatten( + structure, none_is_leaf=True, namespace="keras" + ) + return leaves + + +@keras_export("keras.tree.unflatten_as") +def unflatten_as(structure, flat_sequence): + """Unflattens a sequence into a given structure. + + If `structure` is a scalar, `flat_sequence` must be a single-element list; + in this case the return value is ``flat_sequence[0]``. + + If `structure` is or contains a dict instance, the keys will be sorted to + pack the flat sequence in deterministic order. This is true also for + `collections.OrderedDict` instances: their sequence order is considered. + The same convention is followed in `flatten`. This correctly unflattens + dicts and `OrderedDict` after they have been flattened, or vice-versa. + + Dictionaries with non-sortable keys cannot be unflattened. + + Examples: + + >>> keras.tree.unflatten_as([[1, 2], [[3], [4]]], [5, 6, 7, 8]) + [[5, 6], [[7], [8]]] + >>> keras.tree.unflatten_as(None, [1]) + 1 + >>> keras.tree.unflatten_as({1: None, 2: None}, ['Hello', 'world!']) + {1: 'Hello', 2: 'world!'} + + Args: + structure: Arbitrarily nested structure. + flat_sequence: Sequence to unflatten. + + Returns: + `flat_sequence` unflattened into `structure`. + """ + if not is_nested(flat_sequence): + raise TypeError( + f"flat_sequence must be a sequence not a {type(flat_sequence)}:\n" + f"{flat_sequence}" + ) + if not is_nested(structure): + if len(flat_sequence) != 1: + raise ValueError( + "Structure is a scalar but " + f"len(flat_sequence) == {len(flat_sequence)} > 1" + ) + return flat_sequence[0] + structure_spec = optree.tree_structure( + structure, none_is_leaf=True, namespace="keras" + ) + return structure_spec.unflatten(flat_sequence) + + +@keras_export("keras.tree.map_structure") +def map_structure(func, *structures): + """Maps `func` through given structures. + + Examples: + >>> structure = [[1], [2], [3]] + >>> keras.tree.map_structure(lambda v: v**2, structure) + [[1], [4], [9]] + >>> keras.tree.map_structure(lambda x, y: x * y, structure, structure) + [[1], [4], [9]] -def map_structure_up_to(shallow_structure, func, *structures, **kwargs): - return tree.map_structure_up_to( - shallow_structure, func, *structures, **kwargs + >>> Foo = collections.namedtuple('Foo', ['a', 'b']) + >>> structure = Foo(a=1, b=2) + >>> keras.tree.map_structure(lambda v: v * 2, structure) + Foo(a=2, b=4) + + Args: + func: A callable that accepts as many arguments as there are structures. + *structures: Arbitrarily nested structures of the same layout. + + Returns: + A new structure with the same layout as the given ones. + """ + if not callable(func): + raise TypeError(f"`func` must be callable. Received: func={func}") + if not structures: + raise ValueError("Must provide at least one structure") + for other in structures[1:]: + assert_same_structure(structures[0], other, check_types=False) + return optree.tree_map( + func, *structures, none_is_leaf=True, namespace="keras" ) -def traverse(func, structure, top_down=True): - return tree.traverse(func, structure, top_down=top_down) +@keras_export("keras.tree.map_structure_up_to") +def map_structure_up_to(shallow_structure, func, *structures): + """Maps `func` through given structures up to `shallow_structure`. + This is a variant of `map_structure` which only maps the given structures + up to `shallow_structure`. All further nested components are retained as-is. -def assert_same_structure(a, b, check_types=True): - return tree.assert_same_structure(a, b, check_types=check_types) + Examples: + >>> shallow_structure = [None, None] + >>> structure = [[1, 1], [2, 2]] + >>> keras.tree.map_structure_up_to(shallow_structure, len, structure) + [2, 2] -def sequence_like(instance, args): - """Converts the sequence `args` to the same type as `instance`. + >>> shallow_structure = [None, [None, None]] + >>> keras.tree.map_structure_up_to(shallow_structure, str, structure) + ['[1, 1]', ['2', '2']] Args: - instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or - `collections.OrderedDict`. - args: elements to be converted to the `instance` type. + shallow_structure: A structure with layout common to all `structures`. + func: A callable that accepts as many arguments as there are structures. + *structures: Arbitrarily nested structures of the same layout. Returns: - `args` with the type of `instance`. + A new structure with the same layout as `shallow_structure`. + """ + return _map_structure_with_path_up_to( + shallow_structure, + lambda _, *args: func(*args), # Discards path. + *structures, + ) + + +@keras_export("keras.tree.assert_same_structure") +def assert_same_structure(a, b, check_types=True): + """Asserts that two structures are nested in the same way. + + Note that namedtuples with identical name and fields will not be considered + as same structures even `check_types=False`. + + Examples: + + >>> keras.tree.assert_same_structure([(0, 1)], [(2, 3)]) + + >>> Foo = collections.namedtuple('Foo', ['a', 'b']) + >>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b']) + >>> keras.tree.assert_same_structure(Foo(0, 1), Foo(2, 3)) + >>> keras.tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3)) + Traceback (most recent call last): + ... + ValueError: `a` and `b` don't have the same structure. + ... + + Args: + a: an arbitrarily nested structure. + b: an arbitrarily nested structure. + check_types: if `True` (default) types of leaves are checked as well. """ - return tree._sequence_like(instance, args) + a_structure = optree.tree_structure(a, none_is_leaf=True, namespace="keras") + b_structure = optree.tree_structure(b, none_is_leaf=True, namespace="keras") + if a_structure != b_structure: + raise ValueError( + "`a` and `b` don't have the same structure. " + f"Received: structure of a={a_structure}, " + f"structure of b={b_structure}" + ) + if check_types: + type_structure = optree.tree_map( + lambda x, y: type(x) is type(y), + a, + b, + none_is_leaf=True, + namespace="keras", + ) + if not optree.tree_all( + type_structure, none_is_leaf=True, namespace="keras" + ): + raise TypeError( + "The type of the leaves of `a` and `b` doesn't match." + ) +@keras_export("keras.tree.pack_sequence_as") def pack_sequence_as(structure, flat_sequence, sequence_fn=None): - """Implements sequence packing, i.e. nest.pack_sequence_as().""" - is_nested_fn = tree.is_nested - sequence_fn = sequence_fn or tree._sequence_like + """Returns a given flattened sequence packed into a given structure. + + If `structure` is an atom, `flat_sequence` must be a single-item list; in + this case the return value is `flat_sequence[0]`. + + If `structure` is or contains a dict instance, the keys will be sorted to + pack the flat sequence in deterministic order. This is true also for + `OrderedDict` instances: their sequence order is considered. The same + convention is followed in `flatten`. This correctly repacks dicts and + `OrderedDicts` after they have been flattened, or vice-versa. + + Dictionaries with non-sortable keys cannot be flattened. + + Examples: + + >>> structure = {"key3": "", "key1": "", "key2": ""} + >>> flat_sequence = ["value1", "value2", "value3"] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + {"key3": "value3", "key1": "value1", "key2": "value2"} + + >>> structure = (("a", "b"), ("c", "d", "e"), "f") + >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) + + >>> structure = {"key3": {"c": ("alpha", "beta"), "a": ("gamma")}, + ... "key1": {"e": "val1", "d": "val2"}} + >>> flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}} + + >>> structure = ["a"] + >>> flat_sequence = [np.array([[1, 2], [3, 4]])] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + [array([[1, 2], + [3, 4]])] + + >>> structure = ["a"] + >>> flat_sequence = [keras.ops.ones([2, 2])] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + [array([[1., 1.], + [1., 1.]]] + + Args: + structure: Arbitrarily nested structure. + flat_sequence: Flat sequence to pack. + sequence_fn: Defaults to `_sequence_like`. + + Returns: + `flat_sequence` converted to have the same recursive structure as + `structure`. + """ + sequence_fn = sequence_fn or _sequence_like def truncate(value, length): value_str = str(value) return value_str[:length] + (value_str[length:] and "...") - if not is_nested_fn(flat_sequence): + if not is_nested(flat_sequence): raise TypeError( "Attempted to pack value:\n {}\ninto a structure, but found " "incompatible type `{}` instead.".format( @@ -58,7 +391,7 @@ def truncate(value, length): ) ) - if not is_nested_fn(structure): + if not is_nested(structure): if len(flat_sequence) != 1: raise ValueError( "The target structure is of type `{}`\n {}\nHowever the input " @@ -74,13 +407,13 @@ def truncate(value, length): return flat_sequence[0] try: - final_index, packed = packed_nest_with_indices( - structure, flat_sequence, 0, is_nested_fn, sequence_fn + final_index, packed = _packed_nest_with_indices( + structure, flat_sequence, 0, sequence_fn ) if final_index < len(flat_sequence): raise IndexError except IndexError: - flat_structure = tree.flatten(structure) + flat_structure = flatten(structure) if len(flat_structure) != len(flat_sequence): # pylint: disable=raise-missing-from raise ValueError( @@ -92,33 +425,147 @@ def truncate(value, length): return sequence_fn(structure, packed) -def packed_nest_with_indices( - structure, flat, index, is_nested_fn, sequence_fn=None -): - """Helper function for pack_sequence_as. +@keras_export("keras.tree.lists_to_tuples") +def lists_to_tuples(structure): + """Converts `list`s to `tuple`s.""" + + def sequence_fn(instance, args): + if isinstance(instance, list): + return tuple(args) + return _sequence_like(instance, args) - Args: - structure: structure to mimic. - flat: Flattened values to output substructure for. - index: Index at which to start reading from flat. - is_nested_fn: Function used to test if a value should - be treated as a nested structure. - sequence_fn: Function used to generate a new structure instance. + return pack_sequence_as( + structure, flatten(structure), sequence_fn=sequence_fn + ) - Returns: - The tuple (new_index, child), where: - * new_index - the updated index into `flat` - having processed `structure`. - * packed - the subset of `flat` corresponding to `structure`, - having started at `index`, and packed into the same nested - format. - """ + +class _MapToNone: + """A special object used as a sentinel within `traverse`.""" + + def __repr__(self): + return "keras.utils.tree._MAP_TO_NONE" + + +_MAP_TO_NONE = _MapToNone() + + +def _yield_flat_up_to(shallow_tree, input_tree, path=()): + if isinstance(shallow_tree, (str, bytes)) or not ( + isinstance( + shallow_tree, (collections.abc.Mapping, collections.abc.Sequence) + ) + or optree.is_namedtuple(shallow_tree) + ): + yield (path, input_tree) + else: + input_tree = dict(_yield_sorted_items(input_tree)) + for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree): + subpath = path + (shallow_key,) + input_subtree = input_tree[shallow_key] + for leaf_path, leaf_value in _yield_flat_up_to( + shallow_subtree, input_subtree, path=subpath + ): + yield (leaf_path, leaf_value) + + +def _multiyield_flat_up_to(shallow_tree, *input_trees): + """Same as `_yield_flat_up_to`, but takes multiple input trees.""" + zipped_iterators = zip( + *[ + _yield_flat_up_to(shallow_tree, input_tree) + for input_tree in input_trees + ] + ) + try: + for paths_and_values in zipped_iterators: + paths, values = zip(*paths_and_values) + yield paths[:1] + values + except KeyError as e: + paths = locals().get("paths", ((),)) + raise ValueError( + f"Could not find key '{e.args[0]}' in some `input_trees`. " + "Please ensure the structure of all `input_trees` are " + "compatible with `shallow_tree`. The last valid path " + f"yielded was {paths[0]}." + ) from e + + +def _map_structure_with_path_up_to(shallow_structure, func, *structures): + results = [] + for path_and_values in _multiyield_flat_up_to( + shallow_structure, *structures + ): + results.append(func(*path_and_values)) + shallow_structure_spec = optree.tree_structure( + shallow_structure, none_is_leaf=True, namespace="keras" + ) + return shallow_structure_spec.unflatten(results) + + +def _sequence_like(instance, args): + # TODO: Support attrs library + if isinstance(instance, (dict, collections.abc.Mapping)): + # Pack dictionaries in a deterministic order by sorting the keys. + # Notice this means that we ignore the original order of `OrderedDict` + # instances. This is intentional, to avoid potential bugs caused by + # mixing ordered and plain dicts (e.g., flattening a dict but using a + # corresponding `OrderedDict` to pack it back). + result = dict(zip(sorted(instance), args)) + keys_and_values = ((key, result[key]) for key in instance) + if isinstance(instance, collections.defaultdict): + # `defaultdict` requires a default factory as the first argument. + return type(instance)(instance.default_factory, keys_and_values) + elif isinstance(instance, types.MappingProxyType): + # MappingProxyType requires a dict to proxy to. + return type(instance)(dict(keys_and_values)) + else: + return type(instance)(keys_and_values) + elif isinstance(instance, collections.abc.MappingView): + # We can't directly construct mapping views, so we create a list instead + return list(args) + elif optree.is_namedtuple(instance): + instance_type = type(instance) + try: + return instance_type(*args) + except Exception as e: + raise TypeError( + f"Couldn't traverse {instance!r} with arguments {args}" + ) from e + else: + # Not a namedtuple + return type(instance)(args) + + +def _yield_sorted_items(iterable): + # TODO: Support attrs library + if isinstance(iterable, collections.abc.Mapping): + # Iterate through dictionaries in a deterministic order by sorting the + # keys. Notice this means that we ignore the original order of + # `OrderedDict` instances. This is intentional, to avoid potential bugs + # caused by mixing ordered and plain dicts (e.g., flattening a dict but + # using a corresponding `OrderedDict` to pack it back). + for key in sorted(iterable): + yield key, iterable[key] + elif optree.is_namedtuple(iterable): + for field in iterable._fields: + yield (field, getattr(iterable, field)) + else: + for item in enumerate(iterable): + yield item + + +def _yield_value(iterable): + for _, v in _yield_sorted_items(iterable): + yield v + + +def _packed_nest_with_indices(structure, flat, index, sequence_fn=None): packed = [] - sequence_fn = sequence_fn or tree._sequence_like - for s in yield_value(structure): - if is_nested_fn(s): - new_index, child = packed_nest_with_indices( - s, flat, index, is_nested_fn, sequence_fn + sequence_fn = sequence_fn or _sequence_like + for s in _yield_value(structure): + if is_nested(s): + new_index, child = _packed_nest_with_indices( + s, flat, index, sequence_fn ) packed.append(sequence_fn(s, child)) index = new_index @@ -126,21 +573,3 @@ def packed_nest_with_indices( packed.append(flat[index]) index += 1 return index, packed - - -def yield_value(iterable): - for _, v in tree._yield_sorted_items(iterable): - yield v - - -def lists_to_tuples(structure): - def sequence_fn(instance, args): - if isinstance(instance, list): - return tuple(args) - return tree._sequence_like(instance, args) - - return pack_sequence_as( - structure, - tree.flatten(structure), - sequence_fn=sequence_fn, - ) diff --git a/keras/utils/tree_test.py b/keras/utils/tree_test.py new file mode 100644 index 00000000000..a5ca84dab5c --- /dev/null +++ b/keras/utils/tree_test.py @@ -0,0 +1,291 @@ +import collections + +import numpy as np + +from keras import ops +from keras import testing +from keras.utils import tree + +STRUCTURE1 = (((1, 2), 3), 4, (5, 6)) +STRUCTURE2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) +STRUCTURE_DIFFERENT_NUM_ELEMENTS = ("spam", "eggs") +STRUCTURE_DIFFERENT_NESTING = (((1, 2), 3), 4, 5, (6,)) + + +class TreeTest(testing.TestCase): + def test_is_nested(self): + self.assertFalse(tree.is_nested("1234")) + self.assertFalse(tree.is_nested(b"1234")) + self.assertFalse(tree.is_nested(bytearray("1234", "ascii"))) + self.assertTrue(tree.is_nested([1, 3, [4, 5]])) + self.assertTrue(tree.is_nested(((7, 8), (5, 6)))) + self.assertTrue(tree.is_nested([])) + self.assertTrue(tree.is_nested({"a": 1, "b": 2})) + self.assertFalse(tree.is_nested(set([1, 2]))) + ones = np.ones([2, 3]) + self.assertFalse(tree.is_nested(ones)) + self.assertFalse(tree.is_nested(np.tanh(ones))) + self.assertFalse(tree.is_nested(np.ones((4, 5)))) + + def test_flatten_and_unflatten(self): + structure = ((3, 4), 5, (6, 7, (9, 10), 8)) + flat = ["a", "b", "c", "d", "e", "f", "g", "h"] + + self.assertEqual(tree.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) + self.assertEqual( + tree.unflatten_as(structure, flat), + (("a", "b"), "c", ("d", "e", ("f", "g"), "h")), + ) + point = collections.namedtuple("Point", ["x", "y"]) + structure = (point(x=4, y=2), ((point(x=1, y=0),),)) + flat = [4, 2, 1, 0] + self.assertEqual(tree.flatten(structure), flat) + restructured_from_flat = tree.unflatten_as(structure, flat) + self.assertEqual(restructured_from_flat, structure) + self.assertEqual(restructured_from_flat[0].x, 4) + self.assertEqual(restructured_from_flat[0].y, 2) + self.assertEqual(restructured_from_flat[1][0][0].x, 1) + self.assertEqual(restructured_from_flat[1][0][0].y, 0) + + self.assertEqual([5], tree.flatten(5)) + self.assertEqual([np.array([5])], tree.flatten(np.array([5]))) + + self.assertEqual("a", tree.unflatten_as(5, ["a"])) + self.assertEqual( + np.array([5]), tree.unflatten_as("scalar", [np.array([5])]) + ) + + with self.assertRaisesRegex(ValueError, "Structure is a scalar"): + tree.unflatten_as("scalar", [4, 5]) + with self.assertRaisesRegex(TypeError, "flat_sequence"): + tree.unflatten_as([4, 5], "bad_sequence") + with self.assertRaises(ValueError): + tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"]) + + self.assertEqual( + tree.unflatten_as({1: None, 2: None}, ["Hello", "world!"]), + {1: "Hello", 2: "world!"}, + ) + + def test_flatten_dict_order(self): + ordered = collections.OrderedDict( + [("d", 3), ("b", 1), ("a", 0), ("c", 2)] + ) + plain = {"d": 3, "b": 1, "a": 0, "c": 2} + ordered_flat = tree.flatten(ordered) + plain_flat = tree.flatten(plain) + self.assertEqual([3, 1, 0, 2], ordered_flat) + self.assertEqual([0, 1, 2, 3], plain_flat) + + def test_unflatten_dict_order(self): + ordered = collections.OrderedDict( + [("d", 0), ("b", 0), ("a", 0), ("c", 0)] + ) + plain = {"d": 0, "b": 0, "a": 0, "c": 0} + seq = [0, 1, 2, 3] + ordered_reconstruction = tree.unflatten_as(ordered, seq) + plain_reconstruction = tree.unflatten_as(plain, seq) + self.assertEqual( + collections.OrderedDict([("d", 0), ("b", 1), ("a", 2), ("c", 3)]), + ordered_reconstruction, + ) + self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) + + def test_map_structure(self): + structure2 = (((7, 8), 9), 10, (11, 12)) + structure1_plus1 = tree.map_structure(lambda x: x + 1, STRUCTURE1) + tree.assert_same_structure(STRUCTURE1, structure1_plus1) + self.assertAllEqual([2, 3, 4, 5, 6, 7], tree.flatten(structure1_plus1)) + structure1_plus_structure2 = tree.map_structure( + lambda x, y: x + y, STRUCTURE1, structure2 + ) + self.assertEqual( + (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), + structure1_plus_structure2, + ) + + self.assertEqual(3, tree.map_structure(lambda x: x - 1, 4)) + + self.assertEqual(7, tree.map_structure(lambda x, y: x + y, 3, 4)) + + # Empty structures + self.assertEqual((), tree.map_structure(lambda x: x + 1, ())) + self.assertEqual([], tree.map_structure(lambda x: x + 1, [])) + self.assertEqual({}, tree.map_structure(lambda x: x + 1, {})) + empty_nt = collections.namedtuple("empty_nt", "") + self.assertEqual( + empty_nt(), tree.map_structure(lambda x: x + 1, empty_nt()) + ) + + # This is checking actual equality of types, empty list != empty tuple + self.assertNotEqual((), tree.map_structure(lambda x: x + 1, [])) + + with self.assertRaisesRegex(TypeError, "callable"): + tree.map_structure("bad", structure1_plus1) + with self.assertRaisesRegex(ValueError, "at least one structure"): + tree.map_structure(lambda x: x) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.map_structure(lambda x, y: None, 3, (3,)) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) + + structure1_list = [[[1, 2], 3], 4, [5, 6]] + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list) + + def test_map_structure_up_to(self): + # Named tuples. + ab_tuple = collections.namedtuple("ab_tuple", "a, b") + op_tuple = collections.namedtuple("op_tuple", "add, mul") + inp_val = ab_tuple(a=2, b=3) + inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) + out = tree.map_structure_up_to( + inp_val, + lambda val, ops: (val + ops.add) * ops.mul, + inp_val, + inp_ops, + ) + self.assertEqual(out.a, 6) + self.assertEqual(out.b, 15) + + # Lists. + data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] + name_list = ["evens", ["odds", "primes"]] + out = tree.map_structure_up_to( + name_list, + lambda name, sec: "first_{}_{}".format(len(sec), name), + name_list, + data_list, + ) + self.assertEqual( + out, ["first_4_evens", ["first_5_odds", "first_3_primes"]] + ) + + def test_assert_same_structure(self): + tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False) + tree.assert_same_structure("abc", 1.0, check_types=False) + tree.assert_same_structure(b"abc", 1.0, check_types=False) + tree.assert_same_structure("abc", 1.0, check_types=False) + tree.assert_same_structure( + bytearray("abc", "ascii"), 1.0, check_types=False + ) + tree.assert_same_structure("abc", np.array([0, 1]), check_types=False) + + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure( + STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS + ) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure([0, 1], np.array([0, 1])) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure(0, [0, 1]) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure((0, 1), [0, 1]) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure([[3], 4], [3, [4]]) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure({"a": 1}, {"b": 1}) + structure1_list = [[[1, 2], 3], 4, [5, 6]] + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure(STRUCTURE1, structure1_list) + tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False) + with self.assertRaisesRegex(ValueError, "have the same structure"): + tree.assert_same_structure( + STRUCTURE1, structure1_list, check_types=False + ) + + def test_pack_sequence_as(self): + structure = {"key3": "", "key1": "", "key2": ""} + flat_sequence = ["value1", "value2", "value3"] + self.assertEqual( + tree.pack_sequence_as(structure, flat_sequence), + {"key3": "value3", "key1": "value1", "key2": "value2"}, + ) + structure = (("a", "b"), ("c", "d", "e"), "f") + flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + self.assertEqual( + tree.pack_sequence_as(structure, flat_sequence), + ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0), + ) + structure = { + "key3": {"c": ("alpha", "beta"), "a": ("gamma")}, + "key1": {"e": "val1", "d": "val2"}, + } + flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0] + self.assertEqual( + tree.pack_sequence_as(structure, flat_sequence), + { + "key3": {"c": (1.0, 2.0), "a": 3.0}, + "key1": {"e": "val1", "d": "val2"}, + }, + ) + structure = ["a"] + flat_sequence = [np.array([[1, 2], [3, 4]])] + self.assertAllClose( + tree.pack_sequence_as(structure, flat_sequence), + [np.array([[1, 2], [3, 4]])], + ) + structure = ["a"] + flat_sequence = [ops.ones([2, 2])] + self.assertAllClose( + tree.pack_sequence_as(structure, flat_sequence), + [ops.ones([2, 2])], + ) + + with self.assertRaisesRegex(TypeError, "Attempted to pack value:"): + structure = ["a"] + flat_sequence = 1 + tree.pack_sequence_as(structure, flat_sequence) + with self.assertRaisesRegex(ValueError, "The target structure is of"): + structure = "a" + flat_sequence = [1, 2] + tree.pack_sequence_as(structure, flat_sequence) + + def test_lists_to_tuples(self): + structure = [1, 2, 3] + self.assertEqual(tree.lists_to_tuples(structure), (1, 2, 3)) + structure = [[1], [2, 3]] + self.assertEqual(tree.lists_to_tuples(structure), ((1,), (2, 3))) + structure = [[1], [2, [3]]] + self.assertEqual(tree.lists_to_tuples(structure), ((1,), (2, (3,)))) + + def test_traverse(self): + # Lists to tuples + structure = [(1, 2), [3], {"a": [4]}] + self.assertEqual( + ((1, 2), (3,), {"a": (4,)}), + tree.traverse( + lambda x: tuple(x) if isinstance(x, list) else x, + structure, + top_down=False, + ), + ) + # EarlyTermination + structure = [(1, [2]), [3, (4, 5, 6)]] + visited = [] + + def visit(x): + visited.append(x) + return "X" if isinstance(x, tuple) and len(x) > 2 else None + + output = tree.traverse(visit, structure) + self.assertEqual([(1, [2]), [3, "X"]], output) + self.assertEqual( + [ + [(1, [2]), [3, (4, 5, 6)]], + (1, [2]), + 1, + [2], + 2, + [3, (4, 5, 6)], + 3, + (4, 5, 6), + ], + visited, + ) diff --git a/requirements-common.txt b/requirements-common.txt index 723a7b5ab31..5d15f7ac615 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -15,5 +15,5 @@ google tensorboard-plugin-profile rich build -dm-tree +optree pytest-cov diff --git a/setup.py b/setup.py index a9f0d38818f..1a3fbf43e8b 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ def get_version(rel_path): "rich", "namex", "h5py", - "dm-tree", + "optree", "ml-dtypes", ], # Supported Python versions