Merge branch 'keras-team:main' into improve-dtype-in-ops
james77777778 authored Sep 21, 2023
2 parents 6841f50 + 5c11fe6 commit 77fca18
Showing 7 changed files with 666 additions and 34 deletions.
102 changes: 92 additions & 10 deletions keras_core/backend/tensorflow/numpy.py
@@ -1,9 +1,11 @@
import builtins
import functools
import math
import warnings

import tensorflow as tf
from tensorflow.experimental import numpy as tfnp
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

from keras_core.backend import config
from keras_core.backend.tensorflow.core import convert_to_tensor
@@ -52,22 +54,82 @@ def subtract(x1, x2):
 
 
 def matmul(x1, x2):
-    if isinstance(x1, tf.SparseTensor):
+    def with_combined_batch_dimensions(a, b, fn_3d):
+        batch_shape = (
+            b.shape[:-2] if isinstance(b, tf.SparseTensor) else a.shape[:-2]
+        )
+        batch_size = math.prod(batch_shape)
+        a_3d = reshape(a, [batch_size] + a.shape[-2:])
+        b_3d = reshape(b, [batch_size] + b.shape[-2:])
+        result = fn_3d(a_3d, b_3d)
+        return reshape(result, batch_shape + result.shape[1:])
+
+    def sparse_sparse_matmul(a, b):
+        dtype = a.values.dtype
+        # Convert SparseTensors to CSR SparseMatrix.
+        a_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
+            a.indices, a.values, a.dense_shape
+        )
+        b_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
+            b.indices, b.values, b.dense_shape
+        )
+        # Compute the CSR SparseMatrix matrix multiplication.
+        result_csr = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
+            a_csr, b_csr, dtype
+        )
+        # Convert the CSR SparseMatrix to a SparseTensor.
+        res = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
+            result_csr, dtype
+        )
+        return tf.SparseTensor(res.indices, res.values, res.dense_shape)
+
+    def embedding_lookup_sparse_dense_matmul(a, b):
         # We need at least one id per rows for embedding_lookup_sparse,
         # otherwise there will be missing rows in the output.
-        x1, _ = tf.sparse.fill_empty_rows(x1, 0)
+        a, _ = tf.sparse.fill_empty_rows(a, 0)
         # We need to split x1 into separate ids and weights tensors. The ids
         # should be the column indices of x1 and the values of the weights
-        # can continue to be the actual x1. The column arrangement of ids and
-        # weights does not matter as we sum over columns. See documentation for
-        # sparse_ops.sparse_tensor_dense_matmul for details.
+        # can continue to be the actual x1. The column arrangement of ids
+        # and weights does not matter as we sum over columns. See details in
+        # the documentation for sparse_ops.sparse_tensor_dense_matmul.
         ids = tf.SparseTensor(
-            indices=x1.indices,
-            values=x1.indices[:, 1],
-            dense_shape=x1.dense_shape,
+            indices=a.indices,
+            values=a.indices[:, 1],
+            dense_shape=a.dense_shape,
         )
-        weights = x1
-        return tf.nn.embedding_lookup_sparse(x2, ids, weights, combiner="sum")
+        return tf.nn.embedding_lookup_sparse(b, ids, a, combiner="sum")
+
+    # Either a or b is sparse
+    def sparse_dense_matmul_3d(a, b):
+        return tf.map_fn(
+            lambda x: tf.sparse.sparse_dense_matmul(x[0], x[1]),
+            elems=(a, b),
+            fn_output_signature=a.dtype,
+        )
+
+    x1_sparse = isinstance(x1, tf.SparseTensor)
+    x2_sparse = isinstance(x2, tf.SparseTensor)
+    if x1_sparse and x2_sparse:
+        if x1.shape.rank <= 3:
+            return sparse_sparse_matmul(x1, x2)
+        else:
+            return with_combined_batch_dimensions(x1, x2, sparse_sparse_matmul)
+    elif x1_sparse or x2_sparse:
+        # Sparse * dense or dense * sparse
+        sparse_rank = x1.shape.rank if x1_sparse else x2.shape.rank
+
+        # Special case: embedding_lookup_sparse for sparse * dense and rank 2
+        if x1_sparse and sparse_rank == 2:
+            return embedding_lookup_sparse_dense_matmul(x1, x2)
+        elif sparse_rank == 2:
+            return tf.sparse.sparse_dense_matmul(x1, x2)
+        elif sparse_rank == 3:
+            return sparse_dense_matmul_3d(x1, x2)
+        else:
+            return with_combined_batch_dimensions(
+                x1, x2, sparse_dense_matmul_3d
+            )
+
     return tfnp.matmul(x1, x2)
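The sparse dispatch above can be exercised directly against the backend module changed in this diff. The snippet below is an illustrative sketch, not part of the commit; the `knp` alias and the example shapes are assumptions.

    import tensorflow as tf

    from keras_core.backend.tensorflow import numpy as knp

    # Rank-2 sparse x dense goes through embedding_lookup_sparse_dense_matmul.
    a = tf.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[2, 3]
    )
    b = tf.random.uniform((3, 4))
    print(knp.matmul(a, b).shape)  # (2, 4), dense result

    # Rank-2 sparse x sparse goes through the CSR-based sparse_sparse_matmul.
    c = tf.sparse.from_dense(tf.eye(3))
    print(knp.matmul(a, c).dense_shape)  # [2, 3], still a tf.SparseTensor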


@@ -357,6 +419,8 @@ def exp(x):


def expand_dims(x, axis):
    if isinstance(x, tf.SparseTensor):
        return tf.sparse.expand_dims(x, axis)
    return tfnp.expand_dims(x, axis)
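A small illustrative sketch of the new sparse branch (not part of the diff); the module alias is an assumption.

    import tensorflow as tf

    from keras_core.backend.tensorflow import numpy as knp

    st = tf.SparseTensor(indices=[[0, 1]], values=[5.0], dense_shape=[2, 3])
    expanded = knp.expand_dims(st, axis=0)  # stays sparse via tf.sparse.expand_dims
    print(expanded.dense_shape)  # [1, 2, 3]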


@@ -772,6 +836,24 @@ def sqrt(x):


def squeeze(x, axis=None):
    if isinstance(x, tf.SparseTensor):
        new_shape = list(x.shape)
        gather_indices = list(range(len(new_shape)))
        if axis is None:
            for i in range(len(new_shape) - 1, -1, -1):
                if new_shape[i] == 1:
                    del new_shape[i]
                    del gather_indices[i]
        else:
            if new_shape[axis] != 1:
                raise ValueError(
                    f"Cannot squeeze axis {axis}, because the "
                    "dimension is not 1."
                )
            del new_shape[axis]
            del gather_indices[axis]
        new_indices = tf.gather(x.indices, gather_indices, axis=1)
        return tf.SparseTensor(new_indices, x.values, tuple(new_shape))
    return tfnp.squeeze(x, axis=axis)
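Again for illustration only, a sketch of the sparse squeeze path under the same assumptions (backend module alias and shapes are made up for the example).

    import tensorflow as tf

    from keras_core.backend.tensorflow import numpy as knp

    st = tf.SparseTensor(indices=[[0, 0, 1]], values=[5.0], dense_shape=[1, 2, 3])
    squeezed = knp.squeeze(st, axis=0)  # drops the size-1 axis, output stays sparse
    print(squeezed.dense_shape)  # [2, 3]
    # knp.squeeze(st, axis=1) would raise ValueError: that dimension is 2, not 1.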


1 change: 1 addition & 0 deletions keras_core/callbacks/__init__.py
@@ -1,3 +1,4 @@
from keras_core.callbacks.backup_and_restore_callback import BackupAndRestore
from keras_core.callbacks.callback import Callback
from keras_core.callbacks.callback_list import CallbackList
from keras_core.callbacks.csv_logger import CSVLogger
216 changes: 216 additions & 0 deletions keras_core/callbacks/backup_and_restore_callback.py
@@ -0,0 +1,216 @@
import os
import warnings

from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import file_utils


@keras_core_export("keras_core.callbacks.BackupAndRestore")
class BackupAndRestore(Callback):
"""Callback to back up and restore the training state.
`BackupAndRestore` callback is intended to recover training from an
interruption that has happened in the middle of a `Model.fit` execution, by
backing up the training states in a temporary checkpoint file, at the end of
each epoch. Each backup overwrites the previously written checkpoint file,
so at any given time there is at most one such checkpoint file for
backup/restoring purpose.
If training restarts before completion, the training state (which includes
the `Model` weights and epoch number) is restored to the most recently saved
state at the beginning of a new `Model.fit` run. At the completion of a
`Model.fit` run, the temporary checkpoint file is deleted.
Note that the user is responsible to bring jobs back after the interruption.
This callback is important for the backup and restore mechanism for fault
tolerance purpose, and the model to be restored from a previous checkpoint
is expected to be the same as the one used to back up. If user changes
arguments passed to compile or fit, the checkpoint saved for fault tolerance
can become invalid.
Example:
>>> class InterruptingCallback(keras.callbacks.Callback):
... def on_epoch_begin(self, epoch, logs=None):
... if epoch == 4:
... raise RuntimeError('Interrupting!')
>>> callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
>>> model = keras.models.Sequential([keras.layers.Dense(10)])
>>> model.compile(keras.optimizers.SGD(), loss='mse')
>>> try:
... model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
... batch_size=1, callbacks=[callback, InterruptingCallback()],
... verbose=0)
... except:
... pass
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
... epochs=10, batch_size=1, callbacks=[callback],
... verbose=0)
>>> # Only 6 more epochs are run, since first training got interrupted at
>>> # zero-indexed epoch 4, second training will continue from 4 to 9.
>>> len(history.history['loss'])
>>> 6
Args:
file_path: String, path to store the checkpoint.
e.g. `backup_dir = os.path.join(working_dir, "backup")`.
This is the directory in which the system stores temporary files to
recover the model from jobs terminated unexpectedly. The directory
cannot be reused elsewhere to store other files, e.g. by the
`BackupAndRestore` callback of another training run,
or by another callback
(e.g. `ModelCheckpoint`) of the same training.
save_freq: `"epoch"`, integer, or `False`. When set to `"epoch"`
the callback saves the checkpoint at the end of each epoch.
When set to an integer, the callback saves the checkpoint every
`save_freq` batches. Set `save_freq` to `False` if only using
preemption checkpointing (with `save_before_preemption=True`).
delete_checkpoint: Boolean, default to True. This `BackupAndRestore`
callback works by saving a checkpoint to back up the training state.
If `delete_checkpoint=True`, the checkpoint will be deleted after
training is finished. Use `False` if you'd like to keep the checkpoint
for future usage.
save_before_preemption: A boolean value instructing whether to turn on
the automatic checkpoint saving for preemption/maintenance events.
"""

    def __init__(
        self,
        file_path,
        save_freq="epoch",
        delete_checkpoint=True,
        save_before_preemption=False,
    ):
        super().__init__()
        self._current_epoch = 0
        self.save_freq = save_freq
        self.delete_checkpoint = delete_checkpoint
        self.save_before_preemption = save_before_preemption
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0

        if not file_path:
            raise ValueError("Empty `file_path` argument passed")
        self.file_path = file_path

        if not save_freq and not save_before_preemption:
            raise ValueError(
                "Either `save_freq` or `save_before_preemption` must be set."
            )

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
                "Expected save_freq to be 'epoch' or an integer value."
            )
        if self.save_before_preemption:
            warnings.warn("`save_before_preemption` is not yet implemented.")

    def on_train_begin(self, logs=None):
        """Get the training state from the temporary file and restore it."""
        if self._check_checkpoints_exists(self.file_path):
            self._model.load_weights(filepath=self.file_path)

    def on_train_end(self, logs=None):
        if self.delete_checkpoint and self._check_checkpoints_exists(
            self.file_path
        ):
            self._cleanup_checkpoint()

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        if self.save_freq == "epoch":
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def on_train_batch_end(self, batch, logs=None):
        if self._should_save_on_batch(batch):
            self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)

    def _save_model(self, epoch, batch, logs):
        """Saves the model.

        Args:
            epoch: the epoch this iteration is in.
            batch: the batch this iteration is in. `None` if the `save_freq`
                is set to `"epoch"`.
            logs: the `logs` dict passed in to `on_batch_end` or
                `on_epoch_end`.
        """
        logs = logs or {}

        filepath = self._get_file_path(epoch, batch, logs)
        # Create host directory if it doesn't exist.
        dirname = os.path.dirname(filepath)
        if dirname and not file_utils.exists(dirname):
            file_utils.makedirs(dirname)

        try:
            self._model.save_weights(filepath=filepath, overwrite=True)
        except IsADirectoryError:  # h5py 3.x
            raise IOError(
                "Please specify a non-directory filepath for "
                "BackupAndRestore. Filepath used is an existing "
                f"directory: {filepath}"
            )
        except IOError as e:  # h5py 2.x
            # `e.errno` appears to be `None` so checking the content of
            # `e.args[0]`.
            if "is a directory" in str(e.args[0]).lower():
                raise IOError(
                    "Please specify a non-directory filepath for "
                    "BackupAndRestore. Filepath used is an existing "
                    f"directory: {filepath}"
                )
            # Re-raise the error for any other causes.
            raise e

    def _get_file_path(self, epoch, batch, logs):
        """Returns the file path for the checkpoint."""
        try:
            # `filepath` may contain placeholders such as `{epoch:02d}`,
            # `{batch:02d}` and `{mape:.2f}`. A mismatch between logged
            # metrics and the path's placeholders can cause formatting to
            # fail.
            if batch is None or "batch" in logs:
                file_path = self.file_path.format(epoch=epoch + 1, **logs)
            else:
                file_path = self.file_path.format(
                    epoch=epoch + 1, batch=batch + 1, **logs
                )
        except KeyError as e:
            raise KeyError(
                f'Failed to format this callback filepath: "{self.file_path}". '
                f"Reason: {e}"
            )
        return file_path

    def _should_save_on_batch(self, batch):
        """Handles batch-level saving logic; supports steps_per_execution."""
        if self.save_freq == "epoch":
            return False
        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1  # batches are zero-indexed.
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False

    def _cleanup_checkpoint(self):
        """Delete the backup checkpoint at `file_path` if it exists."""
        if self._check_checkpoints_exists(filepath=self.file_path):
            file_utils.rmtree(self.file_path)

    def _check_checkpoints_exists(self, filepath):
        return file_utils.exists(filepath)
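For illustration (not part of the commit), a minimal sketch of using the new callback with batch-level saving; the backup path and the `.weights.h5` suffix expected by `save_weights` are assumptions.

    import numpy as np

    from keras_core import callbacks, layers, models

    model = models.Sequential([layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

    # Back up the training state every 5 batches instead of once per epoch,
    # and keep the checkpoint around after fit() finishes.
    backup = callbacks.BackupAndRestore(
        file_path="/tmp/backup/latest.weights.h5",  # assumed file name
        save_freq=5,
        delete_checkpoint=False,
    )
    model.fit(
        np.random.rand(64, 3),
        np.random.rand(64, 1),
        epochs=2,
        batch_size=4,
        callbacks=[backup],
    )
    # A later fit() with the same callback restores these weights in
    # on_train_begin before training resumes.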
