Replace dm-tree with optree #19306

Merged · 9 commits · Mar 15, 2024
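This PR moves Keras's nested-structure utilities from dm-tree to optree and, along the way, replaces direct `tf.nest.*` calls in the TensorFlow backend with the backend-agnostic `keras.utils.tree` wrappers. A minimal sketch of the drop-in behavior the diffs below rely on, assuming `keras.utils.tree` mirrors the `tf.nest` signatures (`map_structure`, `flatten`, `pack_sequence_as`):

```python
from keras.utils import tree

nested = {"a": [1, 2], "b": (3, {"c": 4})}
# map_structure applies the function to every leaf, preserving structure
print(tree.map_structure(lambda x: x * 2, nested))
# {'a': [2, 4], 'b': (6, {'c': 8})}
# flatten visits dict keys in sorted order, like tf.nest
print(tree.flatten(nested))
# [1, 2, 3, 4]
```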
7 changes: 3 additions & 4 deletions keras/backend/tensorflow/core.py
@@ -9,6 +9,7 @@
from keras.backend.common.name_scope import name_scope as base_name_scope
from keras.backend.common.stateless_scope import StatelessScope
from keras.backend.common.stateless_scope import in_stateless_scope
from keras.utils import tree
from keras.utils.naming import auto_name

SUPPORTS_SPARSE_TENSORS = True
@@ -189,7 +190,7 @@ def convert_keras_tensor_to_tf(x):
)
return x

args, kwargs = tf.nest.map_structure(
args, kwargs = tree.map_structure(
convert_keras_tensor_to_tf, (args, kwargs)
)
tf_out = fn(*args, **kwargs)
@@ -201,9 +202,7 @@ def convert_tf_to_keras_tensor(x):
)
return x

output_spec = tf.nest.map_structure(
convert_tf_to_keras_tensor, tf_out
)
output_spec = tree.map_structure(convert_tf_to_keras_tensor, tf_out)
return output_spec


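In core.py the conversion helpers now run through `tree.map_structure` over the combined `(args, kwargs)` tuple, so a single call converts every leaf on both sides. A small sketch of that idiom (`convert` here is a hypothetical stand-in for `convert_keras_tensor_to_tf`):

```python
from keras.utils import tree

def convert(x):
    # hypothetical stand-in: promote ints to floats, pass everything else through
    if isinstance(x, int) and not isinstance(x, bool):
        return float(x)
    return x

# one map_structure call walks positional args and kwargs together
args, kwargs = tree.map_structure(convert, ((1, 2), {"name": "x"}))
print(args, kwargs)  # (1.0, 2.0) {'name': 'x'}
```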
13 changes: 7 additions & 6 deletions keras/backend/tensorflow/layer.py
@@ -3,6 +3,7 @@
from keras.backend.tensorflow.trackable import KerasAutoTrackable
from keras.utils import tf_utils
from keras.utils import tracking
from keras.utils import tree


class TFLayer(KerasAutoTrackable):
@@ -27,16 +28,16 @@ def _set_save_spec(self, inputs, args=None, kwargs=None):
if self._saved_model_inputs_spec is not None:
return # Already set.

inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
inputs_spec = tree.map_structure(tf_utils.get_tensor_spec, inputs)
args_spec = tree.map_structure(tf_utils.get_tensor_spec, args or [])
kwargs_spec = {}
# Filter out non-tensor arguments from kwargs.
for key, kwarg in kwargs.items():
flat_kwarg = tf.nest.flatten(kwarg)
flat_kwarg = tree.flatten(kwarg)
flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
if any(s is None for s in flat_specs):
continue
kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)
kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs)

self._saved_model_inputs_spec = inputs_spec
self._saved_model_arg_spec = (
@@ -94,7 +95,7 @@ def _default_save_signature(self):

if inputs is not None:
input_signature = [
tf.nest.map_structure(
tree.map_structure(
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
inputs,
)
@@ -108,7 +109,7 @@
]
else:
input_signature = [
tf.nest.map_structure(
tree.map_structure(
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
shapes_dict,
)
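The kwargs filtering in `_set_save_spec` leans on the flatten/pack round trip: flatten a nested kwarg, build a spec per leaf, and re-pack the specs into the original structure. A hedged sketch of that round trip (`to_spec` is a hypothetical stand-in for `tf_utils.get_tensor_spec`):

```python
from keras.utils import tree

def to_spec(x):
    # hypothetical stand-in: numbers become "specs", anything else is None
    return ("spec", x) if isinstance(x, (int, float)) else None

kwarg = {"mask": [1.0, 2.0]}
flat_kwarg = tree.flatten(kwarg)            # [1.0, 2.0]
flat_specs = [to_spec(v) for v in flat_kwarg]
if all(s is not None for s in flat_specs):
    # pack_sequence_as consumes one flat item per leaf position
    packed = tree.pack_sequence_as(kwarg, flat_specs)
    print(packed)  # {'mask': [('spec', 1.0), ('spec', 2.0)]}
```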
23 changes: 11 additions & 12 deletions keras/backend/tensorflow/numpy.py
@@ -17,6 +17,7 @@
from keras.backend.common.backend_utils import to_tuple_or_list
from keras.backend.tensorflow import sparse
from keras.backend.tensorflow.core import convert_to_tensor
from keras.utils import tree


@sparse.elementwise_binary_union(tf.sparse.add)
Expand Down Expand Up @@ -95,7 +96,7 @@ def _normalize_einsum_subscripts(subscripts):


def einsum(subscripts, *operands, **kwargs):
operands = tf.nest.map_structure(convert_to_tensor, operands)
operands = tree.map_structure(convert_to_tensor, operands)
subscripts = _normalize_einsum_subscripts(subscripts)

def is_valid_for_custom_ops(subscripts, *operands):
@@ -240,15 +241,15 @@ def use_custom_ops(subscripts, *operands, output_type):
# output_type="int32"
if "int" in compute_dtype and output_type is None:
compute_dtype = config.floatx()
operands = tf.nest.map_structure(
operands = tree.map_structure(
lambda x: tf.cast(x, compute_dtype), operands
)
result = use_custom_ops(subscripts, *operands, output_type=output_type)
else:
# TODO: tf.einsum doesn't support integer dtype with gpu
if "int" in compute_dtype:
compute_dtype = config.floatx()
operands = tf.nest.map_structure(
operands = tree.map_structure(
lambda x: tf.cast(x, compute_dtype), operands
)
result = tf.einsum(subscripts, *operands, **kwargs)
@@ -763,11 +764,11 @@ def concatenate(xs, axis=0):
)
for x in xs
]
xs = tf.nest.map_structure(convert_to_tensor, xs)
xs = tree.map_structure(convert_to_tensor, xs)
dtype_set = set([x.dtype for x in xs])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs)
xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs)
return tf.concat(xs, axis=axis)


@@ -872,7 +873,7 @@ def digitize(x, bins):
bins = list(bins)

# bins must be float type
bins = tf.nest.map_structure(lambda x: float(x), bins)
bins = tree.map_structure(lambda x: float(x), bins)

# TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8
# int16, uint8, uint16, uint32
@@ -1023,7 +1024,7 @@ def hstack(xs):
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
rank = tf.rank(xs[0])
return tf.cond(
tf.equal(rank, 1),
@@ -1328,9 +1329,7 @@ def ndim(x):
def nonzero(x):
x = convert_to_tensor(x)
result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1)
return tf.nest.map_structure(
lambda indices: tf.cast(indices, "int32"), result
)
return tree.map_structure(lambda indices: tf.cast(indices, "int32"), result)


def not_equal(x1, x2):
@@ -1620,7 +1619,7 @@ def stack(x, axis=0):
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x)
x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x)
return tf.stack(x, axis=axis)


@@ -1807,7 +1806,7 @@ def vstack(xs):
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
return tf.concat(xs, axis=0)


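`concatenate`, `hstack`, `stack`, and `vstack` all share one recipe: gather the leaf dtypes, resolve a common result type, then cast every operand with a single `map_structure` pass. A sketch of that recipe, assuming a TensorFlow install (the real code resolves the target dtype via `dtypes.result_type`; `"float32"` is hard-coded here):

```python
import tensorflow as tf
from keras.utils import tree

xs = [tf.constant([1, 2], "int32"), tf.constant([3.0], "float32")]
if len({x.dtype for x in xs}) > 1:
    # promoted dtype hard-coded for the sketch; Keras computes it properly
    xs = tree.map_structure(lambda x: tf.cast(x, "float32"), xs)
print(tf.concat(xs, axis=0))
# tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
```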
14 changes: 7 additions & 7 deletions keras/backend/tensorflow/trainer.py
@@ -225,7 +225,7 @@ def multi_step_on_data(data):
outputs = one_step_on_data_distributed(data[:1])
for single_step_data in data[1:]:
step_outputs = one_step_on_data_distributed([single_step_data])
outputs = tf.nest.map_structure(
outputs = tree.map_structure(
lambda t1, t2: concat([t1, t2]), outputs, step_outputs
)
return outputs
@@ -473,7 +473,7 @@ def predict(

def append_to_outputs(batch_outputs, outputs):
if outputs is None:
outputs = tf.nest.map_structure(
outputs = tree.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
@@ -521,7 +521,7 @@ def get_data(iterator):
outputs = tree.map_structure_up_to(
batch_outputs, potentially_ragged_concat, outputs
)
return tf.nest.map_structure(convert_to_np_if_not_ragged, outputs)
return tree.map_structure(convert_to_np_if_not_ragged, outputs)

def train_on_batch(
self,
@@ -549,7 +549,7 @@ def data():
yield (x, y, sample_weight)

logs = self.train_function(data())
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
logs = tree.map_structure(lambda x: np.array(x), logs)
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
@@ -568,15 +568,15 @@ def data():
yield (x, y, sample_weight)

logs = self.test_function(data())
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
logs = tree.map_structure(lambda x: np.array(x), logs)
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)

def predict_on_batch(self, x):
self.make_predict_function()
batch_outputs = self.predict_function([(x,)])
batch_outputs = tf.nest.map_structure(
batch_outputs = tree.map_structure(
convert_to_np_if_not_ragged, batch_outputs
)
return batch_outputs
@@ -771,7 +771,7 @@ def _reduce(v):
f"Received: reduction={reduction}."
)

return tf.nest.map_structure(_reduce, values)
return tree.map_structure(_reduce, values)


def _multi_worker_concat(v, strategy):
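`multi_step_on_data` relies on `map_structure` accepting several structures at once: matching leaves from `outputs` and `step_outputs` are passed pairwise to the callback, which is how per-step results get merged leaf by leaf. A toy sketch with strings standing in for tensors (string joining stands in for `concat`):

```python
from keras.utils import tree

outputs = {"loss": "step0", "acc": "step0"}
step_outputs = {"loss": "step1", "acc": "step1"}
# leaves from both structures are paired positionally
merged = tree.map_structure(lambda t1, t2: t1 + "+" + t2, outputs, step_outputs)
print(merged)  # {'loss': 'step0+step1', 'acc': 'step0+step1'}
```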
21 changes: 10 additions & 11 deletions keras/export/export_lib.py
@@ -8,6 +8,7 @@
from keras.models import Functional
from keras.models import Sequential
from keras.utils import io_utils
from keras.utils import tree
from keras.utils.module_utils import tensorflow as tf


@@ -143,16 +144,16 @@ def track(self, resource):
# Variables in the lists below are actually part of the trackables
# that get saved, because the lists are created in __init__.
if backend.backend() == "jax":
self._tf_trackable.variables += tf.nest.flatten(
tf.nest.map_structure(tf.Variable, resource.variables)
self._tf_trackable.variables += tree.flatten(
tree.map_structure(tf.Variable, resource.variables)
)
self._tf_trackable.trainable_variables += tf.nest.flatten(
tf.nest.map_structure(
self._tf_trackable.trainable_variables += tree.flatten(
tree.map_structure(
tf.Variable, resource.trainable_variables
)
)
self._tf_trackable.non_trainable_variables += tf.nest.flatten(
tf.nest.map_structure(
self._tf_trackable.non_trainable_variables += tree.flatten(
tree.map_structure(
tf.Variable, resource.non_trainable_variables
)
)
@@ -362,9 +363,7 @@ def add_variable_collection(self, name, variables):
f"{list(set(type(v) for v in variables))}"
)
if backend.backend() == "jax":
variables = tf.nest.flatten(
tf.nest.map_structure(tf.Variable, variables)
)
variables = tree.flatten(tree.map_structure(tf.Variable, variables))
setattr(self._tf_trackable, name, list(variables))

def write_out(self, filepath, options=None):
@@ -470,7 +469,7 @@ def _convert_jax2tf_function(self, fn, input_signature):

def _spec_to_poly_shape(self, spec):
if isinstance(spec, (dict, list)):
return tf.nest.map_structure(self._spec_to_poly_shape, spec)
return tree.map_structure(self._spec_to_poly_shape, spec)
spec_shape = spec.shape
spec_shape = str(spec_shape).replace("None", "b")
return spec_shape
@@ -500,7 +499,7 @@ def export_model(model, filepath):
export_archive = ExportArchive()
export_archive.track(model)
if isinstance(model, (Functional, Sequential)):
input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
input_signature = tree.map_structure(_make_tensor_spec, model.inputs)
if isinstance(input_signature, list) and len(input_signature) > 1:
input_signature = [input_signature]
export_archive.add_endpoint("serve", model.__call__, input_signature)
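The JAX branch in export_lib.py composes the two primitives: `map_structure` wraps each leaf while the structure is still known, then `flatten` collapses the result to the flat list the trackable stores. A sketch with a hypothetical `Var` class standing in for `tf.Variable`:

```python
from keras.utils import tree

class Var:
    # hypothetical stand-in for tf.Variable
    def __init__(self, value):
        self.value = value
    def __repr__(self):
        return f"Var({self.value})"

variables = {"kernel": [1.0, 2.0], "bias": 3.0}
flat = tree.flatten(tree.map_structure(Var, variables))
print(flat)  # [Var(3.0), Var(1.0), Var(2.0)] -- dict keys flatten in sorted order
```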
2 changes: 1 addition & 1 deletion keras/models/cloning.py
@@ -261,7 +261,7 @@ def _clone_layer(layer):
)
try:
tree.assert_same_structure(input_tensors, model.input)
except TypeError as e:
except (ValueError, TypeError) as e:
raise ValueError(
"`input_tensors` must have the same structure as model.input"
f"\nReference structure: {model.input}"
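The widened except clause in cloning.py reflects an error-type difference: the optree-backed `assert_same_structure` can surface a mismatch as `ValueError` where the previous backend raised `TypeError`, so both are now translated into the descriptive `ValueError` that follows. A sketch of the defensive pattern:

```python
from keras.utils import tree

reference = {"a": 1, "b": 2}
candidate = [1, 2]
try:
    tree.assert_same_structure(reference, candidate)
except (ValueError, TypeError) as e:
    # re-raise here with a structure dump in the real code
    print(f"structure mismatch: {e}")
```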
5 changes: 2 additions & 3 deletions keras/ops/core_test.py
@@ -789,9 +789,8 @@ def test_cond_check_output_spec_list_tuple(self):

def test_cond_check_output_spec_other_types(self):
cond_op = core.Cond()
# Create mock objects with dtype and shape attributes
mock_spec1 = Mock(dtype="float32", shape=(2, 2))
mock_spec2 = Mock(dtype="float32", shape=(2, 2))
mock_spec1 = KerasTensor(shape=(2, 2), dtype="float32")
mock_spec2 = KerasTensor(shape=(2, 2), dtype="float32")
self.assertTrue(cond_op._check_output_spec(mock_spec1, mock_spec2))

def test_cond_check_output_spec_none(self):
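The test swap replaces `Mock` specs with real `KerasTensor`s, presumably because a `Mock` fabricates any attribute it is asked for and behaves unpredictably once structures pass through optree. Constructing the equivalent real spec is cheap:

```python
from keras import KerasTensor

# a symbolic spec with only shape/dtype, no backing data
spec = KerasTensor(shape=(2, 2), dtype="float32")
print(spec.shape, spec.dtype)  # (2, 2) float32
```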
36 changes: 36 additions & 0 deletions keras/utils/tracking.py
@@ -1,5 +1,8 @@
from functools import wraps

import optree
import optree.utils

from keras.backend.common.global_state import get_global_attribute
from keras.backend.common.global_state import set_global_attribute
from keras.utils import python_utils
@@ -110,6 +113,7 @@ def add_to_store(self, store_name, value):
self.stored_ids[store_name].add(id(value))


@optree.register_pytree_node_class(namespace="keras")
class TrackedList(list):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
@@ -160,7 +164,17 @@ def __delitem__(self, index):
if self.tracker:
self.tracker.untrack(value)

def tree_flatten(self):
# For optree
return (self, None)

@classmethod
def tree_unflatten(cls, metadata, children):
# For optree
return cls(children)


@optree.register_pytree_node_class(namespace="keras")
class TrackedDict(dict):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
@@ -199,7 +213,20 @@ def clear(self):
self.tracker.untrack(value)
super().clear()

def tree_flatten(self):
# For optree
keys, values = optree.utils.unzip2(
optree.utils.total_order_sorted(self.items(), key=lambda kv: kv[0])
)
return values, list(keys), keys

@classmethod
def tree_unflatten(cls, keys, values):
# For optree
return cls(optree.utils.safe_zip(keys, values))


@optree.register_pytree_node_class(namespace="keras")
class TrackedSet(set):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
@@ -233,3 +260,12 @@ def clear(self):
for value in self:
self.tracker.untrack(value)
super().clear()

def tree_flatten(self):
# For optree
return (self, None)

@classmethod
def tree_unflatten(cls, metadata, children):
# For optree
return cls(children)
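Each tracked container now implements optree's node protocol: `tree_flatten` returns `(children, metadata)` — `TrackedDict` additionally returns path entries as a third element, with items sorted via `total_order_sorted` so flatten order stays deterministic — and `tree_unflatten` rebuilds the container, with registration scoped to the `"keras"` namespace. A standalone sketch of the same protocol exercised through optree directly (the `Box` class and `"demo"` namespace are illustrative, not part of this PR):

```python
import optree

@optree.register_pytree_node_class(namespace="demo")
class Box(list):
    def tree_flatten(self):
        # children, metadata (no extra path entries needed here)
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(children)

# flatten recurses through registered nodes and builtin containers alike
leaves, treedef = optree.tree_flatten(Box([1, {"k": 2}]), namespace="demo")
print(leaves)  # [1, 2]
rebuilt = optree.tree_unflatten(treedef, [x * 10 for x in leaves])
print(rebuilt)  # [10, {'k': 20}]
```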