Add JAXEpochIterator and speed up get_numpy_iterator (#18991)
* Add `JAXEpochIterator`

* Update `JAXEpochIterator`

* Update naming

* Add prefetching to `TorchEpochIterator` and replace `EpochIterator` with `TorchEpochIterator`

* Update `_get_iterator` to resolve "auto" as "np"

* Fix typo
james77777778 authored Dec 26, 2023
1 parent 9723c79 commit df10cb2
Showing 4 changed files with 109 additions and 35 deletions.
85 changes: 59 additions & 26 deletions keras/backend/jax/trainer.py
@@ -1,3 +1,5 @@
import collections
import itertools
from functools import partial

import jax
@@ -336,7 +338,7 @@ def fit(
) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)

# Create an iterator that yields batches for one epoch.
epoch_iterator = EpochIterator(
epoch_iterator = JAXEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
@@ -378,7 +380,7 @@ def fit(
metrics_variables = [v.value for v in self.metrics_variables]

self._purge_model_variables()
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
for step, data in epoch_iterator.enumerate_epoch():
# Callbacks
callbacks.on_train_batch_begin(step)

@@ -389,7 +391,6 @@ def fit(
optimizer_variables,
metrics_variables,
)
data = self._distribute_data(data)
logs, state = self.train_function(state, data)
(
trainable_variables,
@@ -426,9 +427,9 @@ def fit(

# Run validation.
if validation_data and self._should_eval(epoch, validation_freq):
# Create EpochIterator for evaluation and cache it.
# Create JAXEpochIterator for evaluation and cache it.
if getattr(self, "_eval_epoch_iterator", None) is None:
self._eval_epoch_iterator = EpochIterator(
self._eval_epoch_iterator = JAXEpochIterator(
x=val_x,
y=val_y,
sample_weight=val_sample_weight,
@@ -493,7 +494,7 @@ def evaluate(
epoch_iterator = self._eval_epoch_iterator
else:
# Create an iterator that yields batches of input/target data.
epoch_iterator = EpochIterator(
epoch_iterator = JAXEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
@@ -531,15 +532,14 @@ def evaluate(
metrics_variables = [v.value for v in self.metrics_variables]

self._purge_model_variables(optimizer_variables=False)
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
for step, data in epoch_iterator.enumerate_epoch():
callbacks.on_test_batch_begin(step)

state = (
trainable_variables,
non_trainable_variables,
metrics_variables,
)
data = self._distribute_data(data)
logs, state = self.test_function(state, data)
(
trainable_variables,
@@ -579,7 +579,7 @@ def predict(
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
# Create an iterator that yields batches of input data.
epoch_iterator = EpochIterator(
epoch_iterator = JAXEpochIterator(
x=x,
batch_size=batch_size,
steps_per_epoch=steps,
@@ -589,7 +589,7 @@

if not all(layer.built for layer in self._flatten_layers()):
# Build the model on one batch of data.
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
for _, data in epoch_iterator.enumerate_epoch():
# Build model
x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data[0])
with backend.StatelessScope():
@@ -634,9 +634,8 @@ def append_to_outputs(batch_outputs, outputs):
]
state = (trainable_variables, non_trainable_variables)
outputs = None
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
for step, x in epoch_iterator.enumerate_epoch():
callbacks.on_predict_batch_begin(step)
x = self._distribute_data(x)
batch_outputs, state = self.predict_function(state, x)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
@@ -666,7 +665,7 @@ def train_on_batch(
y, class_weight
)
data = (x, y, sample_weight)
data = self._distribute_data(data)
data = _distribute_data(data)

# Maybe build model
self._symbolic_build(data_batch=data)
@@ -719,7 +718,7 @@ def test_on_batch(
self._assert_compile_called("test_on_batch")

data = (x, y, sample_weight)
data = self._distribute_data(data)
data = _distribute_data(data)
# Maybe build model
self._symbolic_build(data_batch=data)
self._record_training_state_sharding_spec()
@@ -795,18 +794,6 @@ def jax_state_sync(self):
for ref_v, v in zip(self.metrics_variables, metrics_variables):
ref_v.assign(v)

def _distribute_data(self, data):
distribution = distribution_lib.distribution()
if distribution is not None:

def distribute_single_value(d):
layout = distribution.get_data_layout(d.shape)
return jax_distribution_lib.distribute_data_input(d, layout)

return jax.tree_util.tree_map(distribute_single_value, data)
else:
return data

def _record_training_state_sharding_spec(self):
self._trainable_variable_shardings = [
v.value.sharding for v in self.trainable_variables
@@ -901,3 +888,49 @@ def _purge_model_variables(
if metric_variables:
for v in self.metrics_variables:
v._value = None


def _distribute_data(data):
distribution = distribution_lib.distribution()
if distribution is not None:

def distribute_single_value(d):
layout = distribution.get_data_layout(d.shape)
return jax_distribution_lib.distribute_data_input(d, layout)

return jax.tree_util.tree_map(distribute_single_value, data)
else:
return jax.tree_util.tree_map(jax.device_put, data)


class JAXEpochIterator(EpochIterator):
def _get_iterator(self, return_type="auto"):
if return_type in ("np", "auto"):
# enable prefetching when using numpy_iterator
return self._prefetch_numpy_iterator(super()._get_iterator("np"))
return super()._get_iterator(return_type)

def _prefetch_numpy_iterator(self, numpy_iterator):
"""Shard and prefetch batches on device.
Most of the implementation has been borrowed from
`flax.jax_utils.prefetch_to_device`
This utility takes an iterator and returns a new iterator which fills an
on device prefetch buffer. Eager prefetching can improve the performance
of training loops significantly by overlapping compute and data
transfer.
"""
queue = collections.deque()

# If you're training on GPUs, 2 is generally the best choice because
# this guarantees that you can overlap a training step on GPU with a
# data prefetch step on CPU.
def enqueue(n=2):
for data in itertools.islice(numpy_iterator, n):
queue.append(_distribute_data(data))

enqueue(n=2) # TODO: should we make `n` configurable?
while queue:
yield queue.popleft()
enqueue(1)
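
A minimal, standalone sketch of the deque-based prefetch pattern that `JAXEpochIterator._prefetch_numpy_iterator` implements above. The helper name `prefetch_to_device`, the `buffer_size` parameter, and the toy batches are invented for illustration; only the enqueue/dequeue structure and the `jax.device_put` transfer mirror the committed code (the real iterator goes through `_distribute_data` so it can shard under a `Distribution`).

import collections
import itertools

import jax
import numpy as np


def prefetch_to_device(numpy_iterator, buffer_size=2):
    # Keep a small buffer of batches that have already been sent to the
    # default device, topping it up as batches are consumed.
    queue = collections.deque()

    def enqueue(n):
        for batch in itertools.islice(numpy_iterator, n):
            queue.append(jax.tree_util.tree_map(jax.device_put, batch))

    enqueue(buffer_size)
    while queue:
        yield queue.popleft()
        enqueue(1)


# Toy usage: five dict batches of host-side NumPy data.
batches = ({"x": np.ones((8, 4), "float32")} for _ in range(5))
for device_batch in prefetch_to_device(batches):
    pass  # each batch is already a jax.Array on the device

With a buffer of two, the host-to-device copy of batch k+1 can run while the accelerator is still busy with batch k, which is where the overlap described in the docstring comes from.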
49 changes: 42 additions & 7 deletions keras/backend/torch/trainer.py
@@ -1,3 +1,5 @@
import collections
import itertools
import warnings

import numpy as np
@@ -8,8 +10,8 @@
from keras import backend
from keras import callbacks as callbacks_module
from keras import optimizers as optimizers_module
from keras.backend.torch.core import get_device
from keras.trainers import data_adapters
from keras.trainers import epoch_iterator
from keras.trainers import trainer as base_trainer
from keras.trainers.data_adapters import data_adapter_utils
from keras.trainers.epoch_iterator import EpochIterator
@@ -207,7 +209,7 @@ def fit(
) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)

# Create an iterator that yields batches for one epoch.
epoch_iterator = EpochIterator(
epoch_iterator = TorchEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
@@ -263,9 +265,9 @@ def fit(

# Run validation.
if validation_data and self._should_eval(epoch, validation_freq):
# Create EpochIterator for evaluation and cache it.
# Create TorchEpochIterator for evaluation and cache it.
if getattr(self, "_eval_epoch_iterator", None) is None:
self._eval_epoch_iterator = EpochIterator(
self._eval_epoch_iterator = TorchEpochIterator(
x=val_x,
y=val_y,
sample_weight=val_sample_weight,
@@ -328,7 +330,7 @@ def evaluate(
epoch_iterator = self._eval_epoch_iterator
else:
# Create an iterator that yields batches of input/target data.
epoch_iterator = EpochIterator(
epoch_iterator = TorchEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
@@ -494,11 +496,44 @@ def predict_on_batch(self, x):
return batch_outputs


class TorchEpochIterator(epoch_iterator.EpochIterator):
class TorchEpochIterator(EpochIterator):
def _get_iterator(self, return_type="auto"):
if return_type == "auto" and isinstance(
self.data_adapter, data_adapters.TorchDataLoaderAdapter
):
return self.data_adapter.get_torch_dataloader()

elif return_type in ("np", "auto"):
# enable prefetching when using numpy_iterator
return self._prefetch_numpy_iterator(super()._get_iterator("np"))
return super()._get_iterator(return_type)

def _prefetch_numpy_data(self, data):
def to_device(d):
return torch.as_tensor(d, device=get_device())

return tree.map_structure(to_device, data)

def _prefetch_numpy_iterator(self, numpy_iterator):
"""Prefetch batches on device.
The idea has been borrowed from
`torchtnt.utils.data.CudaDataPrefetcher`
This utility takes an iterator and returns a new iterator which fills an
on device prefetch buffer. Eager prefetching can improve the performance
of training loops significantly by overlapping compute and data
transfer.
"""
queue = collections.deque()

# If you're training on GPUs, 2 is generally the best choice because
# this guarantees that you can overlap a training step on GPU with a
# data prefetch step on CPU.
def enqueue(n=2):
for data in itertools.islice(numpy_iterator, n):
queue.append(self._prefetch_numpy_data(data))

enqueue(n=2) # TODO: should we make `n` configurable?
while queue:
yield queue.popleft()
enqueue(1)
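
As a quick, self-contained illustration of the conversion `_prefetch_numpy_data` performs, the sketch below pins the device to the CPU for reproducibility (in the trainer the target comes from `get_device()`, which would normally be the accelerator when one is available). For a CPU target, `torch.as_tensor` reuses the NumPy buffer rather than copying it; for a CUDA target the same call issues the host-to-device copy that the deque above overlaps with compute.

import numpy as np
import torch

batch = np.arange(6, dtype="float32").reshape(2, 3)
tensor = torch.as_tensor(batch, device="cpu")

# Shared buffer on CPU: mutating the NumPy array is visible in the tensor.
batch[0, 0] = 42.0
assert tensor[0, 0].item() == 42.0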
4 changes: 3 additions & 1 deletion keras/trainers/data_adapters/tf_dataset_adapter.py
@@ -1,3 +1,4 @@
import numpy as np
import tree

from keras.trainers.data_adapters import data_adapter_utils
@@ -41,7 +42,8 @@ def get_numpy_iterator(self):
def convert_to_numpy(x):
if isinstance(x, tf.SparseTensor):
x = tf.sparse.to_dense(x)
return x.numpy()
# shared memory using `np.asarray`
return np.asarray(x)

for batch in self._dataset:
yield tree.map_structure(convert_to_numpy, batch)
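
A small sketch of what the adapter change above means for a single eager tensor. The diff's comment attributes the speedup to shared memory via `np.asarray`; the sketch does not try to verify the memory sharing and only shows that the new conversion yields an equivalent NumPy array, which is what the training loop relies on. The `tf.SparseTensor` densification branch is unchanged; only the final dense conversion differs.

import numpy as np
import tensorflow as tf

batch = tf.constant([[1.0, 2.0], [3.0, 4.0]])
arr = np.asarray(batch)  # what the adapter now yields per element

assert isinstance(arr, np.ndarray)
np.testing.assert_array_equal(arr, batch.numpy())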
6 changes: 5 additions & 1 deletion keras/trainers/data_adapters/torch_data_loader_adapter.py
@@ -1,3 +1,4 @@
import numpy as np
import tree

from keras.trainers.data_adapters.data_adapter import DataAdapter
@@ -22,7 +23,10 @@ def __init__(self, dataloader):

def get_numpy_iterator(self):
for batch in self._dataloader:
yield tuple(tree.map_structure(lambda x: x.cpu().numpy(), batch))
# shared memory using `np.asarray`
yield tuple(
tree.map_structure(lambda x: np.asarray(x.cpu()), batch)
)

def get_torch_dataloader(self):
return self._dataloader
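
And the mirror-image conversion for the torch adapter, again as a hedged sketch: once `.cpu()` has brought the batch to host memory, `np.asarray` wraps the tensor's existing storage instead of allocating a new array, so writes through the tensor stay visible in the yielded NumPy view.

import numpy as np
import torch

cpu_batch = torch.arange(6, dtype=torch.float32).reshape(2, 3)
arr = np.asarray(cpu_batch)  # the adapter applies this after `.cpu()`

# Shared storage: writes through the tensor show up in the array.
cpu_batch[0, 0] = 7.0
assert arr[0, 0] == 7.0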
