Commit

Reduce code redundancy

fchollet committed Dec 20, 2023
1 parent 68f2cbd commit efbd02e
Showing 3 changed files with 71 additions and 138 deletions.
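At a glance: the JAX trainer's `_eager_build` and the torch trainer's `_symbolic_build` were near-duplicates, so the commit moves a single `_symbolic_build` helper into the shared base `Trainer` class and has both backends call it. The sketch below is ours, not part of the diff (the name `build_from_batch` is hypothetical); it shows the core trick both helpers shared: replace a concrete batch with `KerasTensor` placeholders and trace the model once, which creates its variables without running real computation.

# Illustrative sketch only -- not part of the commit. Mirrors the trick
# used by the consolidated Trainer._symbolic_build; `build_from_batch`
# is a hypothetical name introduced for this write-up.
import tree

from keras import backend


def build_from_batch(model, x_batch):
    """Build `model` from one batch of inputs, without real computation."""

    def to_symbolic_input(v):
        if v is None:
            return None
        # A placeholder carrying only shape/dtype; no data flows through.
        return backend.KerasTensor(v.shape, backend.standardize_dtype(v.dtype))

    x_spec = tree.map_structure(to_symbolic_input, x_batch)
    # Tracing on placeholders triggers variable creation ("building")
    # exactly like a real forward pass would, minus the compute.
    return backend.compute_output_spec(model, x_spec)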
68 changes: 4 additions & 64 deletions keras/backend/jax/trainer.py
@@ -62,42 +62,6 @@ def compute_loss_and_updates(
             loss = self.optimizer.scale_loss(loss)
         return loss, (unscaled_loss, y_pred, non_trainable_variables)
 
-    def _eager_build(self, data_batch):
-        model_unbuilt = not all(layer.built for layer in self._flatten_layers())
-        compile_metrics_unbuilt = (
-            self._compile_metrics is not None
-            and not self._compile_metrics.built
-        )
-        if model_unbuilt or compile_metrics_unbuilt:
-
-            def _convert_data_to_spec(d):
-                if d is None:
-                    return None
-                return backend.KerasTensor(d.shape, d.dtype)
-
-            data_spec = tree.map_structure(_convert_data_to_spec, data_batch)
-            (
-                x_spec,
-                y_spec,
-                sample_weight_spec,
-            ) = data_adapter_utils.unpack_x_y_sample_weight(data_spec)
-            # Note that this __call__ run the forward path and trigger variable
-            # creation.
-            y_pred_spec = backend.compute_output_spec(self.__call__, x_spec)
-            if compile_metrics_unbuilt:
-                # This will trigger the metric variable creation.
-                backend.compute_output_spec(
-                    self.compute_metrics,
-                    x_spec,
-                    y_spec,
-                    y_pred_spec,
-                    sample_weight=sample_weight_spec,
-                )
-
-        if self.optimizer is not None and not self.optimizer.built:
-            # Build optimizer
-            self.optimizer.build(self.trainable_variables)
-
     def train_step(self, state, data):
         (
             trainable_variables,
@@ -383,20 +347,7 @@ def fit(
             steps_per_execution=self.steps_per_execution,
         )
 
-        needs_building = (
-            not all(layer.built for layer in self._flatten_layers())
-            or not self.optimizer.built
-            or (
-                self._compile_metrics is not None
-                and not self._compile_metrics.built
-            )
-        )
-        if needs_building:
-            # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
-                data_batch = data[0]
-                self._eager_build(data_batch)
-                break
+        self._symbolic_build(iterator=epoch_iterator)
 
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
Expand Down Expand Up @@ -552,18 +503,7 @@ def evaluate(
steps_per_execution=self.steps_per_execution,
)

needs_building = not all(
layer.built for layer in self._flatten_layers()
) or (
self._compile_metrics is not None
and not self._compile_metrics.built
)
if needs_building:
# Build the model on one batch of data.
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
data_batch = data[0]
self._eager_build(data_batch)
break
self._symbolic_build(iterator=epoch_iterator)

# Container that configures and calls callbacks.
if not isinstance(callbacks, callbacks_module.CallbackList):
@@ -729,7 +669,7 @@ def train_on_batch(
         data = self._distribute_data(data)
 
         # Maybe build model
-        self._eager_build(data)
+        self._symbolic_build(data_batch=data)
         self._record_training_state_sharding_spec()
         self.make_train_function()
 
@@ -781,7 +721,7 @@ def test_on_batch(
         data = (x, y, sample_weight)
         data = self._distribute_data(data)
         # Maybe build model
-        self._eager_build(data)
+        self._symbolic_build(data_batch=data)
         self._record_training_state_sharding_spec()
         self.make_test_function()
 
78 changes: 4 additions & 74 deletions keras/backend/torch/trainer.py
@@ -8,9 +8,6 @@
 from keras import backend
 from keras import callbacks as callbacks_module
 from keras import optimizers as optimizers_module
-from keras.backend.common import standardize_dtype
-from keras.backend.common.keras_tensor import KerasTensor
-from keras.backend.torch.core import is_tensor
 from keras.trainers import data_adapters
 from keras.trainers import epoch_iterator
 from keras.trainers import trainer as base_trainer
Expand Down Expand Up @@ -163,55 +160,6 @@ def one_step_on_data(data):
else:
self.predict_function = one_step_on_data

def _symbolic_build(self, data_batch):
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
compile_metrics_unbuilt = (
self._compile_metrics is not None
and not self._compile_metrics.built
)
if model_unbuilt or compile_metrics_unbuilt:
# Create symbolic tensors matching an input batch.

def to_symbolic_input(v):
if is_tensor(v):
return KerasTensor(v.shape, standardize_dtype(v.dtype))
return v

data_batch = tree.map_structure(to_symbolic_input, data_batch)
(
x,
y,
sample_weight,
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
# Build all model state with `backend.compute_output_spec`.
try:
y_pred = backend.compute_output_spec(self, x)
except Exception as e:
raise RuntimeError(
"Unable to automatically build the model. "
"Please build it yourself before calling "
"fit/evaluate/predict. "
"A model is 'built' when its variables have "
"been created and its `self.built` attribute "
"is True. Usually, calling the model on a batch "
"of data is the right way to build it.\n"
"Exception encountered:\n"
f"'{e}'"
)
if compile_metrics_unbuilt:
# Build all metric state with `backend.compute_output_spec`.
backend.compute_output_spec(
self.compute_metrics,
x,
y,
y_pred,
sample_weight=sample_weight,
)
if self.optimizer is not None and not self.optimizer.built:
# Build optimizer
self.optimizer.build(self.trainable_variables)
self._post_build()

@traceback_utils.filter_traceback
def fit(
self,
@@ -270,20 +218,7 @@ def fit(
             steps_per_execution=self.steps_per_execution,
         )
 
-        needs_building = (
-            not all(layer.built for layer in self._flatten_layers())
-            or not self.optimizer.built
-            or (
-                self._compile_metrics is not None
-                and not self._compile_metrics.built
-            )
-        )
-        if needs_building:
-            # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch():
-                data_batch = data[0]
-                self._symbolic_build(data_batch)
-                break
+        self._symbolic_build(iterator=epoch_iterator)
 
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
@@ -403,12 +338,7 @@ def evaluate(
             steps_per_execution=self.steps_per_execution,
         )
 
-        if not all(layer.built for layer in self._flatten_layers()):
-            # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch():
-                data_batch = data[0]
-                self._symbolic_build(data_batch)
-                break
+        self._symbolic_build(iterator=epoch_iterator)
 
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
@@ -525,7 +455,7 @@ def train_on_batch(
         data = (x, y, sample_weight)
 
         # Maybe build model
-        self._symbolic_build(data)
+        self._symbolic_build(data_batch=data)
         self.make_train_function()
 
         logs = self.train_function([data])
@@ -546,7 +476,7 @@ def test_on_batch(
         data = (x, y, sample_weight)
 
         # Maybe build model
-        self._symbolic_build(data)
+        self._symbolic_build(data_batch=data)
        self.make_test_function()
 
         logs = self.test_function([data])
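One detail the consolidation changes in passing: the deleted torch helper converted only torch tensors to symbolic inputs (via `is_tensor`) and passed other leaves through, while the shared helper converts every non-None leaf. A side-by-side of the two predicates, with both bodies copied from the diff; only the `_old`/`_new` suffixes are ours.

# Both bodies are copied from the diff; only the _old/_new names are ours.
from keras import backend
from keras.backend.common import standardize_dtype
from keras.backend.common.keras_tensor import KerasTensor
from keras.backend.torch.core import is_tensor


def to_symbolic_input_old(v):  # torch backend, deleted
    if is_tensor(v):
        return KerasTensor(v.shape, standardize_dtype(v.dtype))
    return v


def to_symbolic_input_new(v):  # shared Trainer, added
    if v is None:
        return None
    return backend.KerasTensor(v.shape, backend.standardize_dtype(v.dtype))

The new predicate assumes every non-None leaf exposes `.shape` and `.dtype`, which holds for the array batches the epoch iterators yield.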
63 changes: 63 additions & 0 deletions keras/trainers/trainer.py
@@ -1,6 +1,8 @@
 import platform
 import warnings
 
+import tree
+
 from keras import backend
 from keras import metrics as metrics_module
 from keras import ops
@@ -9,6 +11,7 @@
 from keras.saving import serialization_lib
 from keras.trainers.compile_utils import CompileLoss
 from keras.trainers.compile_utils import CompileMetrics
+from keras.trainers.data_adapters import data_adapter_utils
 from keras.utils import traceback_utils
 from keras.utils import tracking
@@ -878,6 +881,66 @@ def _assert_compile_called(self, method_name=None):
                msg += f"calling `{method_name}()`."
            raise ValueError(msg)
 
+    def _symbolic_build(self, iterator=None, data_batch=None):
+        model_unbuilt = not all(layer.built for layer in self._flatten_layers())
+        compile_metrics_unbuilt = (
+            self._compile_metrics is not None
+            and not self._compile_metrics.built
+        )
+        optimizer_unbuilt = (
+            self.optimizer is not None and not self.optimizer.built
+        )
+        if model_unbuilt or compile_metrics_unbuilt or optimizer_unbuilt:
+            if data_batch is None:
+                for _, data in iterator.enumerate_epoch():
+                    data_batch = data[0]
+                    break
+
+        if model_unbuilt or compile_metrics_unbuilt:
+            # Create symbolic tensors matching an input batch.
+
+            def to_symbolic_input(v):
+                if v is None:
+                    return None
+                return backend.KerasTensor(
+                    v.shape, backend.standardize_dtype(v.dtype)
+                )
+
+            data_batch = tree.map_structure(to_symbolic_input, data_batch)
+            (
+                x,
+                y,
+                sample_weight,
+            ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
+            # Build all model state with `backend.compute_output_spec`.
+            try:
+                y_pred = backend.compute_output_spec(self, x)
+            except Exception as e:
+                raise RuntimeError(
+                    "Unable to automatically build the model. "
+                    "Please build it yourself before calling "
+                    "fit/evaluate/predict. "
+                    "A model is 'built' when its variables have "
+                    "been created and its `self.built` attribute "
+                    "is True. Usually, calling the model on a batch "
+                    "of data is the right way to build it.\n"
+                    "Exception encountered:\n"
+                    f"'{e}'"
+                )
+            if compile_metrics_unbuilt:
+                # Build all metric state with `backend.compute_output_spec`.
+                backend.compute_output_spec(
+                    self.compute_metrics,
+                    x,
+                    y,
+                    y_pred,
+                    sample_weight=sample_weight,
+                )
+        if optimizer_unbuilt:
+            # Build optimizer
+            self.optimizer.build(self.trainable_variables)
+        self._post_build()
+
 
 def model_supports_jit(model):
     # XLA not supported with TF on MacOS GPU
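The consolidated helper exposes the two entry points visible in the call sites above: `fit`/`evaluate` pass `iterator=` and let the helper pull one batch itself, while `train_on_batch`/`test_on_batch` pass `data_batch=` directly. A quick smoke test of the build path (ours, not from the commit; assumes a working Keras 3 install on any backend):

# Smoke test (ours, not from the commit): fit() on an unbuilt model
# should route through Trainer._symbolic_build and build everything.
import numpy as np

import keras

model = keras.Sequential([keras.layers.Dense(4), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.rand(32, 8).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)

# Model, optimizer, and compiled metrics are all built afterwards.
assert model.built
assert model.optimizer.built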
