Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 104 additions & 96 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
log = logging.getLogger(__name__)


class EarlyStop(Exception):
    """Control-flow signal raised inside the training loop when the
    early-stopping patience is exceeded; caught by the surrounding
    handler to terminate training cleanly."""


def fit(
state: TrainState,
train_ds: InMemoryDataset,
Expand Down Expand Up @@ -113,124 +117,128 @@ def fit(
epoch_pbar = trange(
start_epoch, n_epochs, desc="Epochs", ncols=100, disable=disable_pbar, leave=True
)
with (
ocp.CheckpointManager(
latest_dir.resolve(), options=options
) as latest_ckpt_manager,
ocp.CheckpointManager(best_dir.resolve(), options=options) as best_ckpt_manager,
):
for epoch in range(start_epoch, n_epochs):
epoch_start_time = time.time()
callbacks.on_epoch_begin(epoch=epoch + 1)

if ema_handler:
ema_handler.update(state.params, epoch)

epoch_loss.update({"train_loss": 0.0})
train_batch_metrics = Metrics.empty()

batch_pbar = trange(
0,
train_steps_per_epoch,
desc="Batches",
ncols=100,
mininterval=1.0,
disable=disable_batch_pbar,
leave=False,
)

for batch_idx in range(train_steps_per_epoch):
callbacks.on_train_batch_begin(batch=batch_idx)

batch = next(batch_train_ds)
(
(state, train_batch_metrics),
batch_loss,
) = train_step(
(state, train_batch_metrics),
batch,
)

epoch_loss["train_loss"] += jnp.mean(batch_loss)
callbacks.on_train_batch_end(batch=batch_idx)
batch_pbar.update()

epoch_loss["train_loss"] /= train_steps_per_epoch
epoch_loss["train_loss"] = float(epoch_loss["train_loss"])

epoch_metrics = {
f"train_{key}": float(val)
for key, val in train_batch_metrics.compute().items()
}

if ema_handler:
ema_handler.update(state.params, epoch)
val_params = ema_handler.ema_params
else:
val_params = state.params

if val_ds is not None:
epoch_loss.update({"val_loss": 0.0})
val_batch_metrics = Metrics.empty()
try:
with (
ocp.CheckpointManager(
latest_dir.resolve(), options=options
) as latest_ckpt_manager,
ocp.CheckpointManager(
best_dir.resolve(), options=options
) as best_ckpt_manager,
):
for epoch in range(start_epoch, n_epochs):
epoch_start_time = time.time()
callbacks.on_epoch_begin(epoch=epoch + 1)

if ema_handler:
ema_handler.update(state.params, epoch)

epoch_loss.update({"train_loss": 0.0})
train_batch_metrics = Metrics.empty()

batch_pbar = trange(
0,
val_steps_per_epoch,
train_steps_per_epoch,
desc="Batches",
ncols=100,
mininterval=1.0,
disable=disable_batch_pbar,
leave=False,
)
for batch_idx in range(val_steps_per_epoch):
batch = next(batch_val_ds)

batch_loss, val_batch_metrics = val_step(
val_params, batch, val_batch_metrics
for batch_idx in range(train_steps_per_epoch):
callbacks.on_train_batch_begin(batch=batch_idx)

batch = next(batch_train_ds)
(
(state, train_batch_metrics),
batch_loss,
) = train_step(
(state, train_batch_metrics),
batch,
)
epoch_loss["val_loss"] += batch_loss

epoch_loss["train_loss"] += jnp.mean(batch_loss)
callbacks.on_train_batch_end(batch=batch_idx)
batch_pbar.update()

epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
epoch_loss["train_loss"] /= train_steps_per_epoch
epoch_loss["train_loss"] = float(epoch_loss["train_loss"])

epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)
epoch_metrics = {
f"train_{key}": float(val)
for key, val in train_batch_metrics.compute().items()
}

epoch_metrics.update({**epoch_loss})
epoch_end_time = time.time()
epoch_metrics.update({"epoch_time": epoch_end_time - epoch_start_time})
if ema_handler:
ema_handler.update(state.params, epoch)
val_params = ema_handler.ema_params
else:
val_params = state.params

if val_ds is not None:
epoch_loss.update({"val_loss": 0.0})
val_batch_metrics = Metrics.empty()

batch_pbar = trange(
0,
val_steps_per_epoch,
desc="Batches",
ncols=100,
mininterval=1.0,
disable=disable_batch_pbar,
leave=False,
)
for batch_idx in range(val_steps_per_epoch):
batch = next(batch_val_ds)

batch_loss, val_batch_metrics = val_step(
val_params, batch, val_batch_metrics
)
epoch_loss["val_loss"] += batch_loss
batch_pbar.update()

epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)

ckpt = {"model": state, "epoch": epoch}
if epoch % ckpt_interval == 0:
latest_ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))
epoch_metrics.update({**epoch_loss})
epoch_end_time = time.time()
epoch_metrics.update({"epoch_time": epoch_end_time - epoch_start_time})

if epoch_metrics["val_loss"] < best_loss:
best_ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))
if abs(epoch_metrics["val_loss"] - best_loss) < patience_min_delta:
early_stopping_counter += 1
ckpt = {"model": state, "epoch": epoch}
if epoch % ckpt_interval == 0:
latest_ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))

if epoch_metrics["val_loss"] < best_loss:
best_ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))
if abs(epoch_metrics["val_loss"] - best_loss) < patience_min_delta:
early_stopping_counter += 1
else:
early_stopping_counter = 0

best_loss = epoch_metrics["val_loss"]
else:
early_stopping_counter = 0
early_stopping_counter += 1

best_loss = epoch_metrics["val_loss"]
else:
early_stopping_counter += 1
callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)

callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)
epoch_pbar.set_postfix(val_loss=epoch_metrics["val_loss"])
epoch_pbar.update()

epoch_pbar.set_postfix(val_loss=epoch_metrics["val_loss"])
epoch_pbar.update()
if patience is not None and early_stopping_counter >= patience:
raise EarlyStop()
except EarlyStop:
log.info(
f"Early stopping patience exceeded. Stopping training after {epoch} epochs."
)

if patience is not None and early_stopping_counter >= patience:
log.info(
"Early stopping patience exceeded. Stopping training after"
f" {epoch} epochs."
)
break
epoch_pbar.close()
callbacks.on_train_end()

Expand Down