diff --git a/keras/backend/jax/trainer.py b/keras/backend/jax/trainer.py
index fa585de43f1..a0eaa72d60f 100644
--- a/keras/backend/jax/trainer.py
+++ b/keras/backend/jax/trainer.py
@@ -62,42 +62,6 @@ def compute_loss_and_updates(
                 loss = self.optimizer.scale_loss(loss)
         return loss, (unscaled_loss, y_pred, non_trainable_variables)
 
-    def _eager_build(self, data_batch):
-        model_unbuilt = not all(layer.built for layer in self._flatten_layers())
-        compile_metrics_unbuilt = (
-            self._compile_metrics is not None
-            and not self._compile_metrics.built
-        )
-        if model_unbuilt or compile_metrics_unbuilt:
-
-            def _convert_data_to_spec(d):
-                if d is None:
-                    return None
-                return backend.KerasTensor(d.shape, d.dtype)
-
-            data_spec = tree.map_structure(_convert_data_to_spec, data_batch)
-            (
-                x_spec,
-                y_spec,
-                sample_weight_spec,
-            ) = data_adapter_utils.unpack_x_y_sample_weight(data_spec)
-            # Note that this __call__ run the forward path and trigger variable
-            # creation.
-            y_pred_spec = backend.compute_output_spec(self.__call__, x_spec)
-            if compile_metrics_unbuilt:
-                # This will trigger the metric variable creation.
-                backend.compute_output_spec(
-                    self.compute_metrics,
-                    x_spec,
-                    y_spec,
-                    y_pred_spec,
-                    sample_weight=sample_weight_spec,
-                )
-
-        if self.optimizer is not None and not self.optimizer.built:
-            # Build optimizer
-            self.optimizer.build(self.trainable_variables)
-
     def train_step(self, state, data):
         (
             trainable_variables,
@@ -383,20 +347,7 @@ def fit(
             steps_per_execution=self.steps_per_execution,
         )
 
-        needs_building = (
-            not all(layer.built for layer in self._flatten_layers())
-            or not self.optimizer.built
-            or (
-                self._compile_metrics is not None
-                and not self._compile_metrics.built
-            )
-        )
-        if needs_building:
-            # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
-                data_batch = data[0]
-                self._eager_build(data_batch)
-                break
+        self._symbolic_build(iterator=epoch_iterator)
 
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
@@ -552,18 +503,7 @@ def evaluate(
             steps_per_execution=self.steps_per_execution,
        )
 
-        needs_building = not all(
-            layer.built for layer in self._flatten_layers()
-        ) or (
-            self._compile_metrics is not None
-            and not self._compile_metrics.built
-        )
-        if needs_building:
-            # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
-                data_batch = data[0]
-                self._eager_build(data_batch)
-                break
+        self._symbolic_build(iterator=epoch_iterator)
 
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
@@ -729,7 +669,7 @@ def train_on_batch(
         data = self._distribute_data(data)
 
         # Maybe build model
-        self._eager_build(data)
+        self._symbolic_build(data_batch=data)
         self._record_training_state_sharding_spec()
 
         self.make_train_function()
@@ -781,7 +721,7 @@ def test_on_batch(
         data = (x, y, sample_weight)
         data = self._distribute_data(data)
         # Maybe build model
-        self._eager_build(data)
+        self._symbolic_build(data_batch=data)
         self._record_training_state_sharding_spec()
 
         self.make_test_function()
diff --git a/keras/backend/torch/trainer.py b/keras/backend/torch/trainer.py
index 6f634873dd2..a50d9f87380 100644
--- a/keras/backend/torch/trainer.py
+++ b/keras/backend/torch/trainer.py
@@ -8,9 +8,6 @@
 from keras import backend
 from keras import callbacks as callbacks_module
 from keras import optimizers as optimizers_module
-from keras.backend.common import standardize_dtype
-from keras.backend.common.keras_tensor import KerasTensor
-from keras.backend.torch.core import is_tensor
 from keras.trainers import data_adapters
 from keras.trainers import epoch_iterator
 from keras.trainers import trainer as base_trainer
@@ -163,55 +160,6 @@ def one_step_on_data(data):
         else:
             self.predict_function = one_step_on_data
 
-    def _symbolic_build(self, data_batch):
-        model_unbuilt = not all(layer.built for layer in self._flatten_layers())
-        compile_metrics_unbuilt = (
-            self._compile_metrics is not None
-            and not self._compile_metrics.built
-        )
-        if model_unbuilt or compile_metrics_unbuilt:
-            # Create symbolic tensors matching an input batch.
-
-            def to_symbolic_input(v):
-                if is_tensor(v):
-                    return KerasTensor(v.shape, standardize_dtype(v.dtype))
-                return v
-
-            data_batch = tree.map_structure(to_symbolic_input, data_batch)
-            (
-                x,
-                y,
-                sample_weight,
-            ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
-            # Build all model state with `backend.compute_output_spec`.
-            try:
-                y_pred = backend.compute_output_spec(self, x)
-            except Exception as e:
-                raise RuntimeError(
-                    "Unable to automatically build the model. "
-                    "Please build it yourself before calling "
-                    "fit/evaluate/predict. "
-                    "A model is 'built' when its variables have "
-                    "been created and its `self.built` attribute "
-                    "is True. Usually, calling the model on a batch "
-                    "of data is the right way to build it.\n"
-                    "Exception encountered:\n"
-                    f"'{e}'"
-                )
-            if compile_metrics_unbuilt:
-                # Build all metric state with `backend.compute_output_spec`.
-                backend.compute_output_spec(
-                    self.compute_metrics,
-                    x,
-                    y,
-                    y_pred,
-                    sample_weight=sample_weight,
-                )
-        if self.optimizer is not None and not self.optimizer.built:
-            # Build optimizer
-            self.optimizer.build(self.trainable_variables)
-        self._post_build()
-
     @traceback_utils.filter_traceback
     def fit(
         self,
@@ -270,20 +218,7 @@ def fit(
             steps_per_execution=self.steps_per_execution,
         )
 
-        needs_building = (
-            not all(layer.built for layer in self._flatten_layers())
-            or not self.optimizer.built
-            or (
-                self._compile_metrics is not None
-                and not self._compile_metrics.built
-            )
-        )
-        if needs_building:
-            # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch():
-                data_batch = data[0]
-                self._symbolic_build(data_batch)
-                break
+        self._symbolic_build(iterator=epoch_iterator)
 
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
@@ -403,12 +338,7 @@ def evaluate(
             steps_per_execution=self.steps_per_execution,
         )
 
-        if not all(layer.built for layer in self._flatten_layers()):
-            # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch():
-                data_batch = data[0]
-                self._symbolic_build(data_batch)
-                break
+        self._symbolic_build(iterator=epoch_iterator)
 
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
@@ -525,7 +455,7 @@ def train_on_batch(
         data = (x, y, sample_weight)
 
         # Maybe build model
-        self._symbolic_build(data)
+        self._symbolic_build(data_batch=data)
         self.make_train_function()
 
         logs = self.train_function([data])
@@ -546,7 +476,7 @@ def test_on_batch(
         data = (x, y, sample_weight)
 
         # Maybe build model
-        self._symbolic_build(data)
+        self._symbolic_build(data_batch=data)
         self.make_test_function()
 
         logs = self.test_function([data])
diff --git a/keras/trainers/trainer.py b/keras/trainers/trainer.py
index b674c0c90ee..433100d4db4 100644
--- a/keras/trainers/trainer.py
+++ b/keras/trainers/trainer.py
@@ -1,6 +1,8 @@
 import platform
 import warnings
 
+import tree
+
 from keras import backend
 from keras import metrics as metrics_module
 from keras import ops
@@ -9,6 +11,7 @@
 from keras.saving import serialization_lib
 from keras.trainers.compile_utils import CompileLoss
 from keras.trainers.compile_utils import CompileMetrics
+from keras.trainers.data_adapters import data_adapter_utils
 from keras.utils import traceback_utils
 from keras.utils import tracking
 
@@ -878,6 +881,66 @@ def _assert_compile_called(self, method_name=None):
                 msg += f"calling `{method_name}()`."
             raise ValueError(msg)
 
+    def _symbolic_build(self, iterator=None, data_batch=None):
+        model_unbuilt = not all(layer.built for layer in self._flatten_layers())
+        compile_metrics_unbuilt = (
+            self._compile_metrics is not None
+            and not self._compile_metrics.built
+        )
+        optimizer_unbuilt = (
+            self.optimizer is not None and not self.optimizer.built
+        )
+        if model_unbuilt or compile_metrics_unbuilt or optimizer_unbuilt:
+            if data_batch is None:
+                for _, data in iterator.enumerate_epoch():
+                    data_batch = data[0]
+                    break
+
+        if model_unbuilt or compile_metrics_unbuilt:
+            # Create symbolic tensors matching an input batch.
+
+            def to_symbolic_input(v):
+                if v is None:
+                    return None
+                return backend.KerasTensor(
+                    v.shape, backend.standardize_dtype(v.dtype)
+                )
+
+            data_batch = tree.map_structure(to_symbolic_input, data_batch)
+            (
+                x,
+                y,
+                sample_weight,
+            ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
+            # Build all model state with `backend.compute_output_spec`.
+            try:
+                y_pred = backend.compute_output_spec(self, x)
+            except Exception as e:
+                raise RuntimeError(
+                    "Unable to automatically build the model. "
+                    "Please build it yourself before calling "
+                    "fit/evaluate/predict. "
+                    "A model is 'built' when its variables have "
+                    "been created and its `self.built` attribute "
+                    "is True. Usually, calling the model on a batch "
+                    "of data is the right way to build it.\n"
+                    "Exception encountered:\n"
+                    f"'{e}'"
+                )
+            if compile_metrics_unbuilt:
+                # Build all metric state with `backend.compute_output_spec`.
+                backend.compute_output_spec(
+                    self.compute_metrics,
+                    x,
+                    y,
+                    y_pred,
+                    sample_weight=sample_weight,
+                )
+        if optimizer_unbuilt:
+            # Build optimizer
+            self.optimizer.build(self.trainable_variables)
+        self._post_build()
+
 
 def model_supports_jit(model):
     # XLA not supported with TF on MacOS GPU