
Commit

(re)enable torch.compile in the pytorch trainer for train, predict, and eval
kiukchung committed Oct 6, 2023
1 parent 5536dc5 commit 8212c7d
Showing 7 changed files with 64 additions and 27 deletions.
4 changes: 4 additions & 0 deletions examples/keras_io/vision/mnist_convnet.py
@@ -12,8 +12,12 @@
"""

import numpy as np

import keras as keras
from keras import layers
from keras.utils.traceback_utils import disable_traceback_filtering

disable_traceback_filtering()

"""
## Prepare the data
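
The example now calls `disable_traceback_filtering()` so that full stack traces, including Keras-internal and dynamo frames, stay visible while debugging compiled runs. A minimal sketch (not part of the commit) of toggling this around a debugging session, assuming `enable_traceback_filtering` is the matching re-enable helper in the same module:

from keras.utils.traceback_utils import (
    disable_traceback_filtering,
    enable_traceback_filtering,  # assumed counterpart, not shown in this diff
)

disable_traceback_filtering()  # show full tracebacks while debugging
try:
    # ... build and fit the model under investigation ...
    pass
finally:
    enable_traceback_filtering()  # restore the default filtered traces
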
9 changes: 8 additions & 1 deletion keras/backend/torch/core.py
@@ -149,6 +149,7 @@ def convert_to_tensor(x, dtype=None, sparse=False):
return torch.as_tensor(x, dtype=torch.int32, device=get_device())
if isinstance(x, float):
return torch.as_tensor(x, dtype=torch.float32, device=get_device())

# Convert to np in case of any array-like that is not list or tuple.
if not isinstance(x, (list, tuple)):
x = np.array(x)
@@ -172,6 +173,7 @@ def transform(x):
# Tensor has to be moved to CPU before converting to numpy.
if x.is_cuda or x.is_mps:
x = x.cpu()
return x.numpy()
return np.array(x)

if isinstance(x, (list, tuple)):
@@ -180,7 +182,12 @@


def is_tensor(x):
return torch.is_tensor(x)
# Using the built-in `isinstance` is recommended by PyTorch over
# `torch.is_tensor` (see:
# https://pytorch.org/docs/stable/generated/torch.is_tensor.html).
# Also, `torch.is_tensor()` causes issues with dynamo caching: when a
# torch.Tensor and a numpy.ndarray of the same size, shape, and dtype
# are passed, calling it on the Tensor first makes the subsequent call
# with the ndarray return `True`, and vice-versa.
return isinstance(x, torch.Tensor)


def shape(x):
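
To illustrate the caching issue described in the comment above, here is a short sketch (not part of the commit) contrasting the new `isinstance`-based check with inputs of identical shape and dtype; the `isinstance` check never classifies an ndarray as a tensor, regardless of call order:

import numpy as np
import torch


def is_tensor(x):
    # Same predicate as keras/backend/torch/core.py after this commit.
    return isinstance(x, torch.Tensor)


t = torch.zeros((2, 3), dtype=torch.float32)
a = np.zeros((2, 3), dtype=np.float32)

assert is_tensor(t)      # a real torch.Tensor
assert not is_tensor(a)  # an ndarray is never a torch.Tensor
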
25 changes: 17 additions & 8 deletions keras/backend/torch/trainer.py
@@ -2,6 +2,7 @@

import numpy as np
import torch
import torch._dynamo.config as dynamo_config
import tree

from keras import backend
@@ -25,6 +26,13 @@ def __init__(self):
self.test_function = None
self.predict_function = None

# Ensures maximum compatibility when jit_compile=True
# by instructing dynamo to insert graph breaks and fall back to
# eager Python for functions that cannot be traced.
# Users can set this to False after instantiating the trainer
# to surface trace errors and fix them for better performance.
dynamo_config.suppress_errors = True

def train_step(self, data):
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

@@ -102,12 +110,7 @@ def one_step_on_data(data):
return self.train_step(data)

if self.jit_compile:
raise ValueError(
"`jit_compile` is not yet enabled for the PyTorch backend."
)
# Temporarily disabled torch compile due to failed unit tests.
# TODO: Uncomment the following line when unit tests passes.
# self.train_function = torch.compile(one_step_on_data)
self.train_function = torch.compile(one_step_on_data)
else:
self.train_function = one_step_on_data

@@ -127,7 +130,10 @@ def one_step_on_data(data):
with torch.no_grad():
return self.test_step(data)

self.test_function = one_step_on_data
if self.jit_compile:
self.test_function = torch.compile(one_step_on_data)
else:
self.test_function = one_step_on_data

def make_predict_function(self, force=False):
if self.predict_function is not None and not force:
@@ -145,7 +151,10 @@ def one_step_on_data(data):
with torch.no_grad():
return self.predict_step(data)

self.predict_function = one_step_on_data
if self.jit_compile:
self.predict_function = torch.compile(one_step_on_data)
else:
self.predict_function = one_step_on_data

def _symbolic_build(self, data_batch):
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
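
As the new comment in `__init__` notes, dynamo error suppression can be switched back off to surface tracing failures instead of silently falling back to eager execution. A hedged usage sketch, assuming the torch backend is active; the tiny model and data are illustrative, not from the commit:

import numpy as np
import keras
from keras import layers
import torch._dynamo.config as dynamo_config

model = keras.Sequential([layers.Dense(4, activation="relu"), layers.Dense(1)])

# jit_compile=True now routes the train/test/predict steps through
# torch.compile in the torch trainer.
model.compile(optimizer="adam", loss="mse", jit_compile=True)

# The trainer sets suppress_errors = True in __init__ for maximum
# compatibility; flip it to False so tracing failures raise instead of
# silently falling back to eager execution.
dynamo_config.suppress_errors = False

x = np.random.rand(16, 8).astype("float32")
y = np.random.rand(16, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
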
6 changes: 3 additions & 3 deletions keras/layers/reshaping/flatten.py
@@ -55,10 +55,10 @@ def call(self, inputs):

def compute_output_shape(self, input_shape):
non_batch_dims = input_shape[1:]
if len(non_batch_dims) == 0:
flattened_dim = 1
elif None in non_batch_dims:
if None in non_batch_dims:
flattened_dim = None
elif len(non_batch_dims) == 0:
flattened_dim = 1
else:
flattened_dim = math.prod(non_batch_dims)
return (input_shape[0], flattened_dim)
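
For reference, a short sketch (not part of the commit) of what the reordered `compute_output_shape` returns for a few input shapes; any unknown non-batch dimension now yields an unknown flattened dimension before the empty-shape case is considered:

import math


def compute_output_shape(input_shape):
    # Same branch order as the Flatten layer above: an unknown (None)
    # non-batch dim wins over the "already flat" case.
    non_batch_dims = input_shape[1:]
    if None in non_batch_dims:
        flattened_dim = None
    elif len(non_batch_dims) == 0:
        flattened_dim = 1
    else:
        flattened_dim = math.prod(non_batch_dims)
    return (input_shape[0], flattened_dim)


assert compute_output_shape((None, 28, 28, 1)) == (None, 784)
assert compute_output_shape((None, None, 4)) == (None, None)
assert compute_output_shape((32,)) == (32, 1)
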
2 changes: 1 addition & 1 deletion keras/trainers/epoch_iterator.py
@@ -122,7 +122,7 @@ def enumerate_epoch(self, return_type="auto"):
if buffer:
yield step - len(buffer) + 1, buffer
if not self._num_batches:
# Infer the number of batches returned by the data_adater.
# Infer the number of batches returned by the data_adapter.
# Assumed static.
self._num_batches = step + 1
self.data_adapter.on_epoch_end()
25 changes: 11 additions & 14 deletions keras/trainers/trainer.py
@@ -104,7 +104,7 @@ def compile(
and to set it to `True` when debugging.
steps_per_execution: Int. The number of batches to run
during each a single compiled function call. Running multiple
batches inside a single a single compiled function call can
batches inside a single compiled function call can
greatly improve performance on TPUs or small models with a large
Python overhead. At most, one full epoch will be run each
execution. If a number larger than the size of the epoch is
@@ -115,9 +115,11 @@
each compiled function execution).
Not supported with the PyTorch backend.
jit_compile: Bool or `"auto"`. Whether to use XLA compilation when
compiling a model. Not supported with the PyTorch backend.
If `"auto"`, XLA compilation will be enabled if the
the model supports it, and disabled otherwise.
compiling a model. For the `jax` and `tensorflow` backends,
`jit_compile="auto"` enables XLA compilation if the model supports
it, and disables it otherwise. For the `torch` backend, `"auto"`
defaults to eager execution, while `jit_compile=True` compiles the
train/test/predict step functions with `torch.compile` using the
`"inductor"` backend.
auto_scale_loss: Bool. If `True` and the model dtype policy is
`"mixed_float16"`, the passed optimizer will be automatically
wrapped in a `LossScaleOptimizer`, which will dynamically
@@ -162,12 +164,7 @@
"cannot also be True. Disabling `jit_compile`.",
stacklevel=2,
)
if jit_compile and backend.backend() == "torch":
warnings.warn(
"`jit_compile` is not yet enabled for the PyTorch backend. "
"Proceeding with `jit_compile=False`."
)
jit_compile = False

self.jit_compile = jit_compile
self.run_eagerly = run_eagerly
self.stop_training = False
@@ -866,11 +863,11 @@ def _assert_compile_called(self, method_name=None):


def resolve_auto_jit_compile(model):
if backend.backend() == "torch":
# jit_compile = "auto" with the pytorch backend defaults to eager
return False

if model_supports_jit(model):
if backend.backend() == "torch":
# Torch defaults to eager mode
# until torch compile is reliable
return False
return True
return False

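
Tying the docstring change to `resolve_auto_jit_compile`: on the torch backend, `jit_compile="auto"` now resolves to eager execution, while an explicit `jit_compile=True` is honoured and wrapped in `torch.compile`. A hedged sketch, assuming the torch backend is active (the tiny model and assertions are illustrative, not from the commit):

import keras
from keras import layers

model = keras.Sequential([layers.Dense(3)])

# "auto" resolves to False on torch, so the trainer stays eager.
model.compile(optimizer="sgd", loss="mse", jit_compile="auto")
assert not model.jit_compile

# An explicit True keeps jit_compile on and uses torch.compile
# (inductor backend by default) for the step functions.
model.compile(optimizer="sgd", loss="mse", jit_compile=True)
assert model.jit_compile
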
20 changes: 20 additions & 0 deletions keras/trainers/trainer_test.py
@@ -92,6 +92,14 @@ def call(self, x, training=False):


class TestTrainer(testing.TestCase, parameterized.TestCase):
def setUp(self):
if backend.backend() == "torch":
import torch._dynamo as dynamo

# reset dynamo cache (which is global) between test cases
# so that each test-case is independent
dynamo.reset()

@pytest.mark.requires_trainable_backend
def test_metric_tracking(self):
class ModelWithMetric(Trainer, layers.Dense):
@@ -148,6 +156,18 @@ def __init__(self, units):
)
self.assertEqual(len(model_weighted.metrics), 3)

@pytest.mark.skipif(
backend.backend() != "torch",
reason="torch backend runs in eager mode for jit_compile='auto'",
)
def test_compile_eager_vs_jit_torch(self):
model = ExampleModel(units=3)
model.compile(jit_compile="auto")
# torch trainer en/disables torch.compile only based on the value of
# model.jit_compile (not model.run_eagerly)
self.assertFalse(model.run_eagerly)
self.assertFalse(model.jit_compile)

@parameterized.named_parameters(
[
("eager", True, False, False),
