Replace dm-tree with optree (#19306)
* Refactor `keras.utils.tree`

* Fix tests

* Replace `dm-tree` with `optree`

* Eliminate `tf.nest`

* Resolve comments

* Fix merge conflicts

* Update exporting path
james77777778 authored Mar 15, 2024
1 parent 3fcb38c commit e2b43e2
Showing 12 changed files with 868 additions and 115 deletions.
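Most of the diff below is mechanical: each `tf.nest.*` call becomes the equivalent `keras.utils.tree` helper, which is now backed by `optree` instead of `dm-tree`. A minimal sketch of the shared API, with a toy structure that is illustrative rather than taken from the PR:

from keras.utils import tree

nested = {"a": [1, 2], "b": (3, {"c": 4})}

# tf.nest.map_structure(fn, s)  ->  tree.map_structure(fn, s)
doubled = tree.map_structure(lambda x: x * 2, nested)
# {'a': [2, 4], 'b': (6, {'c': 8})}

# tf.nest.flatten / tf.nest.pack_sequence_as  ->  tree.flatten / tree.pack_sequence_as
leaves = tree.flatten(nested)  # [1, 2, 3, 4]
assert tree.pack_sequence_as(nested, leaves) == nested

As with `tf.nest`, dict keys are traversed in sorted order, so the flatten order is deterministic.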
7 changes: 3 additions & 4 deletions keras/backend/tensorflow/core.py
@@ -9,6 +9,7 @@
 from keras.backend.common.name_scope import name_scope as base_name_scope
 from keras.backend.common.stateless_scope import StatelessScope
 from keras.backend.common.stateless_scope import in_stateless_scope
+from keras.utils import tree
 from keras.utils.naming import auto_name

 SUPPORTS_SPARSE_TENSORS = True
@@ -189,7 +190,7 @@ def convert_keras_tensor_to_tf(x):
                     )
                 return x

-            args, kwargs = tf.nest.map_structure(
+            args, kwargs = tree.map_structure(
                 convert_keras_tensor_to_tf, (args, kwargs)
             )
             tf_out = fn(*args, **kwargs)
@@ -201,9 +202,7 @@ def convert_tf_to_keras_tensor(x):
                     )
                 return x

-            output_spec = tf.nest.map_structure(
-                convert_tf_to_keras_tensor, tf_out
-            )
+            output_spec = tree.map_structure(convert_tf_to_keras_tensor, tf_out)
             return output_spec

13 changes: 7 additions & 6 deletions keras/backend/tensorflow/layer.py
@@ -3,6 +3,7 @@
 from keras.backend.tensorflow.trackable import KerasAutoTrackable
 from keras.utils import tf_utils
 from keras.utils import tracking
+from keras.utils import tree


 class TFLayer(KerasAutoTrackable):
@@ -27,16 +28,16 @@ def _set_save_spec(self, inputs, args=None, kwargs=None):
         if self._saved_model_inputs_spec is not None:
             return  # Already set.

-        inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
-        args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
+        inputs_spec = tree.map_structure(tf_utils.get_tensor_spec, inputs)
+        args_spec = tree.map_structure(tf_utils.get_tensor_spec, args or [])
         kwargs_spec = {}
         # Filter out non-tensor arguments from kwargs.
         for key, kwarg in kwargs.items():
-            flat_kwarg = tf.nest.flatten(kwarg)
+            flat_kwarg = tree.flatten(kwarg)
             flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
             if any(s is None for s in flat_specs):
                 continue
-            kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)
+            kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs)

         self._saved_model_inputs_spec = inputs_spec
         self._saved_model_arg_spec = (
@@ -94,7 +95,7 @@ def _default_save_signature(self):

         if inputs is not None:
             input_signature = [
-                tf.nest.map_structure(
+                tree.map_structure(
                     lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
                     inputs,
                 )
@@ -108,7 +109,7 @@
             ]
         else:
             input_signature = [
-                tf.nest.map_structure(
+                tree.map_structure(
                     lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
                     shapes_dict,
                 )
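Note on `_set_save_spec`: the flatten-transform-repack idiom above carries over unchanged. A toy round-trip of that idiom, with integer leaves and a stand-in transform in place of `tf_utils.get_tensor_spec` (both the values and the stand-in are illustrative):

from keras.utils import tree

kwarg = {"ids": [1, 2], "count": 3}
flat = tree.flatten(kwarg)  # [3, 1, 2] -- dict keys traversed in sorted order
specs = [x * 10 for x in flat]  # stand-in for [tf_utils.get_tensor_spec(x) ...]
rebuilt = tree.pack_sequence_as(kwarg, specs)
# {'ids': [10, 20], 'count': 30}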
23 changes: 11 additions & 12 deletions keras/backend/tensorflow/numpy.py
@@ -17,6 +17,7 @@
 from keras.backend.common.backend_utils import to_tuple_or_list
 from keras.backend.tensorflow import sparse
 from keras.backend.tensorflow.core import convert_to_tensor
+from keras.utils import tree


 @sparse.elementwise_binary_union(tf.sparse.add)
@@ -95,7 +96,7 @@ def _normalize_einsum_subscripts(subscripts):


 def einsum(subscripts, *operands, **kwargs):
-    operands = tf.nest.map_structure(convert_to_tensor, operands)
+    operands = tree.map_structure(convert_to_tensor, operands)
     subscripts = _normalize_einsum_subscripts(subscripts)

     def is_valid_for_custom_ops(subscripts, *operands):
@@ -240,15 +241,15 @@ def use_custom_ops(subscripts, *operands, output_type):
         # output_type="int32"
         if "int" in compute_dtype and output_type is None:
             compute_dtype = config.floatx()
-            operands = tf.nest.map_structure(
+            operands = tree.map_structure(
                 lambda x: tf.cast(x, compute_dtype), operands
             )
         result = use_custom_ops(subscripts, *operands, output_type=output_type)
     else:
         # TODO: tf.einsum doesn't support integer dtype with gpu
         if "int" in compute_dtype:
             compute_dtype = config.floatx()
-            operands = tf.nest.map_structure(
+            operands = tree.map_structure(
                 lambda x: tf.cast(x, compute_dtype), operands
             )
         result = tf.einsum(subscripts, *operands, **kwargs)
@@ -763,11 +764,11 @@ def concatenate(xs, axis=0):
         )
         for x in xs
     ]
-    xs = tf.nest.map_structure(convert_to_tensor, xs)
+    xs = tree.map_structure(convert_to_tensor, xs)
     dtype_set = set([x.dtype for x in xs])
     if len(dtype_set) > 1:
         dtype = dtypes.result_type(*dtype_set)
-        xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs)
+        xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs)
     return tf.concat(xs, axis=axis)

@@ -872,7 +873,7 @@ def digitize(x, bins):
     bins = list(bins)

     # bins must be float type
-    bins = tf.nest.map_structure(lambda x: float(x), bins)
+    bins = tree.map_structure(lambda x: float(x), bins)

     # TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8
     # int16, uint8, uint16, uint32
@@ -1023,7 +1024,7 @@ def hstack(xs):
     dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
     if len(dtype_set) > 1:
         dtype = dtypes.result_type(*dtype_set)
-        xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
+        xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
     rank = tf.rank(xs[0])
     return tf.cond(
         tf.equal(rank, 1),
@@ -1328,9 +1329,7 @@ def ndim(x):
 def nonzero(x):
     x = convert_to_tensor(x)
     result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1)
-    return tf.nest.map_structure(
-        lambda indices: tf.cast(indices, "int32"), result
-    )
+    return tree.map_structure(lambda indices: tf.cast(indices, "int32"), result)


 def not_equal(x1, x2):
@@ -1620,7 +1619,7 @@ def stack(x, axis=0):
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
     if len(dtype_set) > 1:
         dtype = dtypes.result_type(*dtype_set)
-        x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x)
+        x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x)
     return tf.stack(x, axis=axis)

@@ -1807,7 +1806,7 @@ def vstack(xs):
     dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
     if len(dtype_set) > 1:
         dtype = dtypes.result_type(*dtype_set)
-        xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
+        xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
     return tf.concat(xs, axis=0)
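The `concatenate`, `hstack`, `stack`, and `vstack` changes all use the same dtype-promotion idiom: a single `tree.map_structure` call casts every element of the list to a common dtype. A sketch with the promoted dtype hard-coded in place of the `dtypes.result_type` call:

import tensorflow as tf
from keras.utils import tree

xs = [tf.constant([1, 2], "int32"), tf.constant([0.5], "float32")]
dtype = "float32"  # stand-in for dtypes.result_type(*{x.dtype for x in xs})
xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs)
print(tf.concat(xs, axis=0))  # tf.Tensor([1.  2.  0.5], shape=(3,), dtype=float32)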
14 changes: 7 additions & 7 deletions keras/backend/tensorflow/trainer.py
@@ -225,7 +225,7 @@ def multi_step_on_data(data):
             outputs = one_step_on_data_distributed(data[:1])
             for single_step_data in data[1:]:
                 step_outputs = one_step_on_data_distributed([single_step_data])
-                outputs = tf.nest.map_structure(
+                outputs = tree.map_structure(
                     lambda t1, t2: concat([t1, t2]), outputs, step_outputs
                 )
             return outputs
@@ -473,7 +473,7 @@ def predict(

         def append_to_outputs(batch_outputs, outputs):
             if outputs is None:
-                outputs = tf.nest.map_structure(
+                outputs = tree.map_structure(
                     lambda batch_output: [batch_output],
                     batch_outputs,
                 )
@@ -521,7 +521,7 @@ def get_data(iterator):
         outputs = tree.map_structure_up_to(
             batch_outputs, potentially_ragged_concat, outputs
         )
-        return tf.nest.map_structure(convert_to_np_if_not_ragged, outputs)
+        return tree.map_structure(convert_to_np_if_not_ragged, outputs)

     def train_on_batch(
         self,
@@ -549,7 +549,7 @@ def data():
             yield (x, y, sample_weight)

         logs = self.train_function(data())
-        logs = tf.nest.map_structure(lambda x: np.array(x), logs)
+        logs = tree.map_structure(lambda x: np.array(x), logs)
         if return_dict:
             return logs
         return self._flatten_metrics_in_order(logs)
@@ -568,15 +568,15 @@ def data():
             yield (x, y, sample_weight)

         logs = self.test_function(data())
-        logs = tf.nest.map_structure(lambda x: np.array(x), logs)
+        logs = tree.map_structure(lambda x: np.array(x), logs)
         if return_dict:
             return logs
         return self._flatten_metrics_in_order(logs)

     def predict_on_batch(self, x):
         self.make_predict_function()
         batch_outputs = self.predict_function([(x,)])
-        batch_outputs = tf.nest.map_structure(
+        batch_outputs = tree.map_structure(
             convert_to_np_if_not_ragged, batch_outputs
         )
         return batch_outputs
@@ -771,7 +771,7 @@ def _reduce(v):
                 f"Received: reduction={reduction}."
             )

-    return tf.nest.map_structure(_reduce, values)
+    return tree.map_structure(_reduce, values)


 def _multi_worker_concat(v, strategy):
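`tree.map_structure` also accepts multiple structures and zips them leaf by leaf, which is what `multi_step_on_data` relies on to merge per-step outputs. A small sketch with NumPy arrays standing in for the distributed tensors:

import numpy as np
from keras.utils import tree

outputs = {"loss": np.array([0.9]), "acc": np.array([0.5])}
step_outputs = {"loss": np.array([0.7]), "acc": np.array([0.6])}
outputs = tree.map_structure(
    lambda t1, t2: np.concatenate([t1, t2]), outputs, step_outputs
)
# {'loss': array([0.9, 0.7]), 'acc': array([0.5, 0.6])}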
21 changes: 10 additions & 11 deletions keras/export/export_lib.py
@@ -8,6 +8,7 @@
 from keras.models import Functional
 from keras.models import Sequential
 from keras.utils import io_utils
+from keras.utils import tree
 from keras.utils.module_utils import tensorflow as tf


@@ -143,16 +144,16 @@ def track(self, resource):
         # Variables in the lists below are actually part of the trackables
         # that get saved, because the lists are created in __init__.
         if backend.backend() == "jax":
-            self._tf_trackable.variables += tf.nest.flatten(
-                tf.nest.map_structure(tf.Variable, resource.variables)
+            self._tf_trackable.variables += tree.flatten(
+                tree.map_structure(tf.Variable, resource.variables)
             )
-            self._tf_trackable.trainable_variables += tf.nest.flatten(
-                tf.nest.map_structure(
+            self._tf_trackable.trainable_variables += tree.flatten(
+                tree.map_structure(
                     tf.Variable, resource.trainable_variables
                 )
             )
-            self._tf_trackable.non_trainable_variables += tf.nest.flatten(
-                tf.nest.map_structure(
+            self._tf_trackable.non_trainable_variables += tree.flatten(
+                tree.map_structure(
                     tf.Variable, resource.non_trainable_variables
                 )
             )
@@ -362,9 +363,7 @@ def add_variable_collection(self, name, variables):
                 f"{list(set(type(v) for v in variables))}"
             )
         if backend.backend() == "jax":
-            variables = tf.nest.flatten(
-                tf.nest.map_structure(tf.Variable, variables)
-            )
+            variables = tree.flatten(tree.map_structure(tf.Variable, variables))
         setattr(self._tf_trackable, name, list(variables))

     def write_out(self, filepath, options=None):
@@ -470,7 +469,7 @@ def _convert_jax2tf_function(self, fn, input_signature):

     def _spec_to_poly_shape(self, spec):
         if isinstance(spec, (dict, list)):
-            return tf.nest.map_structure(self._spec_to_poly_shape, spec)
+            return tree.map_structure(self._spec_to_poly_shape, spec)
         spec_shape = spec.shape
         spec_shape = str(spec_shape).replace("None", "b")
         return spec_shape
@@ -500,7 +499,7 @@ def export_model(model, filepath):
     export_archive = ExportArchive()
     export_archive.track(model)
     if isinstance(model, (Functional, Sequential)):
-        input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
+        input_signature = tree.map_structure(_make_tensor_spec, model.inputs)
         if isinstance(input_signature, list) and len(input_signature) > 1:
             input_signature = [input_signature]
         export_archive.add_endpoint("serve", model.__call__, input_signature)
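`export_model` builds the serving signature the same way, mapping a spec constructor over `model.inputs`. A sketch of that idiom over a toy nested input (the dict of tensors is illustrative, not the PR's code path):

import tensorflow as tf
from keras.utils import tree

inputs = {"image": tf.zeros((1, 32, 32, 3)), "label": tf.zeros((1,))}
signature = tree.map_structure(
    lambda x: tf.TensorSpec(x.shape, x.dtype), inputs
)
# {'image': TensorSpec(shape=(1, 32, 32, 3), dtype=tf.float32, ...), 'label': ...}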
2 changes: 1 addition & 1 deletion keras/models/cloning.py
@@ -261,7 +261,7 @@ def _clone_layer(layer):
         )
     try:
         tree.assert_same_structure(input_tensors, model.input)
-    except TypeError as e:
+    except (ValueError, TypeError) as e:
         raise ValueError(
             "`input_tensors` must have the same structure as model.input"
             f"\nReference structure: {model.input}"
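The widened `except` clause is needed because the optree-backed `tree.assert_same_structure` can surface structure mismatches as `ValueError` where `dm-tree` raised `TypeError`. A sketch of the guarded call with deliberately mismatched toy structures:

from keras.utils import tree

try:
    tree.assert_same_structure({"a": 1}, [1])
except (ValueError, TypeError) as e:
    print(f"structure mismatch: {e}")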
5 changes: 2 additions & 3 deletions keras/ops/core_test.py
@@ -789,9 +789,8 @@ def test_cond_check_output_spec_list_tuple(self):

     def test_cond_check_output_spec_other_types(self):
         cond_op = core.Cond()
-        # Create mock objects with dtype and shape attributes
-        mock_spec1 = Mock(dtype="float32", shape=(2, 2))
-        mock_spec2 = Mock(dtype="float32", shape=(2, 2))
+        mock_spec1 = KerasTensor(shape=(2, 2), dtype="float32")
+        mock_spec2 = KerasTensor(shape=(2, 2), dtype="float32")
         self.assertTrue(cond_op._check_output_spec(mock_spec1, mock_spec2))

     def test_cond_check_output_spec_none(self):
36 changes: 36 additions & 0 deletions keras/utils/tracking.py
@@ -1,5 +1,8 @@
 from functools import wraps

+import optree
+import optree.utils
+
 from keras.backend.common.global_state import get_global_attribute
 from keras.backend.common.global_state import set_global_attribute
 from keras.utils import python_utils
@@ -110,6 +113,7 @@ def add_to_store(self, store_name, value):
         self.stored_ids[store_name].add(id(value))


+@optree.register_pytree_node_class(namespace="keras")
 class TrackedList(list):
     def __init__(self, values=None, tracker=None):
         self.tracker = tracker
@@ -160,7 +164,17 @@ def __delitem__(self, index):
         if self.tracker:
             self.tracker.untrack(value)

+    def tree_flatten(self):
+        # For optree
+        return (self, None)
+
+    @classmethod
+    def tree_unflatten(cls, metadata, children):
+        # For optree
+        return cls(children)
+

+@optree.register_pytree_node_class(namespace="keras")
 class TrackedDict(dict):
     def __init__(self, values=None, tracker=None):
         self.tracker = tracker
@@ -199,7 +213,20 @@ def clear(self):
                 self.tracker.untrack(value)
         super().clear()

+    def tree_flatten(self):
+        # For optree
+        keys, values = optree.utils.unzip2(
+            optree.utils.total_order_sorted(self.items(), key=lambda kv: kv[0])
+        )
+        return values, list(keys), keys
+
+    @classmethod
+    def tree_unflatten(cls, keys, values):
+        # For optree
+        return cls(optree.utils.safe_zip(keys, values))
+

+@optree.register_pytree_node_class(namespace="keras")
 class TrackedSet(set):
     def __init__(self, values=None, tracker=None):
         self.tracker = tracker
@@ -233,3 +260,12 @@ def clear(self):
         for value in self:
             self.tracker.untrack(value)
         super().clear()
+
+    def tree_flatten(self):
+        # For optree
+        return (self, None)
+
+    @classmethod
+    def tree_unflatten(cls, metadata, children):
+        # For optree
+        return cls(children)
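The `tracking.py` additions register Keras's tracked containers as pytree nodes under a `"keras"` namespace, so `optree` can traverse them like plain lists, dicts, and sets. A self-contained sketch of the same mechanism, using a made-up `Pair` class and `"demo"` namespace (both illustrative, not from the PR):

import optree

@optree.register_pytree_node_class(namespace="demo")
class Pair:
    def __init__(self, first, second):
        self.first, self.second = first, second

    def tree_flatten(self):
        # (children, metadata) -- optree recurses into each child
        return (self.first, self.second), None

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(*children)

p = Pair([1, 2], {"x": 3})
leaves, treespec = optree.tree_flatten(p, namespace="demo")
print(leaves)  # [1, 2, 3]
print(optree.tree_unflatten(treespec, [v * 10 for v in leaves]).second)  # {'x': 30}

Note that `TrackedDict.tree_flatten` returns the three-tuple form `(children, metadata, entries)`, which optree also accepts, and sorts its keys first so the flatten order stays deterministic.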
