diff --git a/apax/train/trainer.py b/apax/train/trainer.py
index 550c600f..9edb5008 100644
--- a/apax/train/trainer.py
+++ b/apax/train/trainer.py
@@ -21,6 +21,10 @@
 log = logging.getLogger(__name__)
 
 
+class EarlyStop(Exception):
+    pass
+
+
 def fit(
     state: TrainState,
     train_ds: InMemoryDataset,
@@ -113,124 +117,128 @@ def fit(
     epoch_pbar = trange(
         start_epoch, n_epochs, desc="Epochs", ncols=100, disable=disable_pbar, leave=True
     )
 
-    with (
-        ocp.CheckpointManager(
-            latest_dir.resolve(), options=options
-        ) as latest_ckpt_manager,
-        ocp.CheckpointManager(best_dir.resolve(), options=options) as best_ckpt_manager,
-    ):
-        for epoch in range(start_epoch, n_epochs):
-            epoch_start_time = time.time()
-            callbacks.on_epoch_begin(epoch=epoch + 1)
-
-            if ema_handler:
-                ema_handler.update(state.params, epoch)
-
-            epoch_loss.update({"train_loss": 0.0})
-            train_batch_metrics = Metrics.empty()
-
-            batch_pbar = trange(
-                0,
-                train_steps_per_epoch,
-                desc="Batches",
-                ncols=100,
-                mininterval=1.0,
-                disable=disable_batch_pbar,
-                leave=False,
-            )
-
-            for batch_idx in range(train_steps_per_epoch):
-                callbacks.on_train_batch_begin(batch=batch_idx)
-
-                batch = next(batch_train_ds)
-                (
-                    (state, train_batch_metrics),
-                    batch_loss,
-                ) = train_step(
-                    (state, train_batch_metrics),
-                    batch,
-                )
-
-                epoch_loss["train_loss"] += jnp.mean(batch_loss)
-                callbacks.on_train_batch_end(batch=batch_idx)
-                batch_pbar.update()
-
-            epoch_loss["train_loss"] /= train_steps_per_epoch
-            epoch_loss["train_loss"] = float(epoch_loss["train_loss"])
-
-            epoch_metrics = {
-                f"train_{key}": float(val)
-                for key, val in train_batch_metrics.compute().items()
-            }
-
-            if ema_handler:
-                ema_handler.update(state.params, epoch)
-                val_params = ema_handler.ema_params
-            else:
-                val_params = state.params
-
-            if val_ds is not None:
-                epoch_loss.update({"val_loss": 0.0})
-                val_batch_metrics = Metrics.empty()
+    try:
+        with (
+            ocp.CheckpointManager(
+                latest_dir.resolve(), options=options
+            ) as latest_ckpt_manager,
+            ocp.CheckpointManager(
+                best_dir.resolve(), options=options
+            ) as best_ckpt_manager,
+        ):
+            for epoch in range(start_epoch, n_epochs):
+                epoch_start_time = time.time()
+                callbacks.on_epoch_begin(epoch=epoch + 1)
+
+                if ema_handler:
+                    ema_handler.update(state.params, epoch)
+
+                epoch_loss.update({"train_loss": 0.0})
+                train_batch_metrics = Metrics.empty()
 
                 batch_pbar = trange(
                     0,
-                    val_steps_per_epoch,
+                    train_steps_per_epoch,
                     desc="Batches",
                     ncols=100,
                     mininterval=1.0,
                     disable=disable_batch_pbar,
                     leave=False,
                 )
-                for batch_idx in range(val_steps_per_epoch):
-                    batch = next(batch_val_ds)
-                    batch_loss, val_batch_metrics = val_step(
-                        val_params, batch, val_batch_metrics
+
+                for batch_idx in range(train_steps_per_epoch):
+                    callbacks.on_train_batch_begin(batch=batch_idx)
+
+                    batch = next(batch_train_ds)
+                    (
+                        (state, train_batch_metrics),
+                        batch_loss,
+                    ) = train_step(
+                        (state, train_batch_metrics),
+                        batch,
                     )
-                    epoch_loss["val_loss"] += batch_loss
+
+                    epoch_loss["train_loss"] += jnp.mean(batch_loss)
+                    callbacks.on_train_batch_end(batch=batch_idx)
                     batch_pbar.update()
 
-                epoch_loss["val_loss"] /= val_steps_per_epoch
-                epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
+                epoch_loss["train_loss"] /= train_steps_per_epoch
+                epoch_loss["train_loss"] = float(epoch_loss["train_loss"])
 
-                epoch_metrics.update(
-                    {
-                        f"val_{key}": float(val)
-                        for key, val in val_batch_metrics.compute().items()
-                    }
-                )
+                epoch_metrics = {
+                    f"train_{key}": float(val)
+                    for key, val in train_batch_metrics.compute().items()
+                }
 
-            epoch_metrics.update({**epoch_loss})
-            epoch_end_time = time.time()
-            epoch_metrics.update({"epoch_time": epoch_end_time - epoch_start_time})
+                if ema_handler:
+                    ema_handler.update(state.params, epoch)
+                    val_params = ema_handler.ema_params
+                else:
+                    val_params = state.params
+
+                if val_ds is not None:
+                    epoch_loss.update({"val_loss": 0.0})
+                    val_batch_metrics = Metrics.empty()
+
+                    batch_pbar = trange(
+                        0,
+                        val_steps_per_epoch,
+                        desc="Batches",
+                        ncols=100,
+                        mininterval=1.0,
+                        disable=disable_batch_pbar,
+                        leave=False,
+                    )
+                    for batch_idx in range(val_steps_per_epoch):
+                        batch = next(batch_val_ds)
+
+                        batch_loss, val_batch_metrics = val_step(
+                            val_params, batch, val_batch_metrics
+                        )
+                        epoch_loss["val_loss"] += batch_loss
+                        batch_pbar.update()
+
+                    epoch_loss["val_loss"] /= val_steps_per_epoch
+                    epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
+
+                    epoch_metrics.update(
+                        {
+                            f"val_{key}": float(val)
+                            for key, val in val_batch_metrics.compute().items()
+                        }
+                    )
 
-            ckpt = {"model": state, "epoch": epoch}
-            if epoch % ckpt_interval == 0:
-                latest_ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))
+                epoch_metrics.update({**epoch_loss})
+                epoch_end_time = time.time()
+                epoch_metrics.update({"epoch_time": epoch_end_time - epoch_start_time})
 
-            if epoch_metrics["val_loss"] < best_loss:
-                best_ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))
-                if abs(epoch_metrics["val_loss"] - best_loss) < patience_min_delta:
-                    early_stopping_counter += 1
+                ckpt = {"model": state, "epoch": epoch}
+                if epoch % ckpt_interval == 0:
+                    latest_ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))
+
+                if epoch_metrics["val_loss"] < best_loss:
+                    best_ckpt_manager.save(epoch, args=ocp.args.StandardSave(ckpt))
+                    if abs(epoch_metrics["val_loss"] - best_loss) < patience_min_delta:
+                        early_stopping_counter += 1
+                    else:
+                        early_stopping_counter = 0
+
+                    best_loss = epoch_metrics["val_loss"]
                 else:
-                    early_stopping_counter = 0
+                    early_stopping_counter += 1
 
-                best_loss = epoch_metrics["val_loss"]
-            else:
-                early_stopping_counter += 1
+                callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)
 
-            callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)
+                epoch_pbar.set_postfix(val_loss=epoch_metrics["val_loss"])
+                epoch_pbar.update()
 
-            epoch_pbar.set_postfix(val_loss=epoch_metrics["val_loss"])
-            epoch_pbar.update()
+                if patience is not None and early_stopping_counter >= patience:
+                    raise EarlyStop()
+    except EarlyStop:
+        log.info(
+            f"Early stopping patience exceeded. Stopping training after {epoch} epochs."
+        )
 
-            if patience is not None and early_stopping_counter >= patience:
-                log.info(
-                    "Early stopping patience exceeded. Stopping training after"
-                    f" {epoch} epochs."
-                )
-                break
-
     epoch_pbar.close()
     callbacks.on_train_end()
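Note on the control flow above: the early-stopping `break` is replaced by a dedicated `EarlyStop` exception raised inside the epoch loop and caught outside the `with` block. Because the exception propagates through the context managers, both `CheckpointManager`s are closed (their `__exit__` cleanup runs) before the early-stop message is logged. Below is a minimal, self-contained sketch of that pattern; `train`, its arguments, and the `checkpoints.log` file are illustrative stand-ins, not apax code.

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)


    class EarlyStop(Exception):
        """Sentinel exception used to unwind out of the nested with/for blocks."""


    def train(n_epochs, patience):
        # Stand-in for fit(); open() stands in for the Orbax CheckpointManagers.
        # Raising inside the `with` block still triggers the context manager's
        # __exit__, so cleanup happens before the message is logged.
        early_stopping_counter = 0
        epoch = 0
        try:
            with open("checkpoints.log", "w") as ckpt_log:
                for epoch in range(n_epochs):
                    ckpt_log.write(f"epoch {epoch}\n")
                    early_stopping_counter += 1  # pretend val_loss never improves
                    if patience is not None and early_stopping_counter >= patience:
                        raise EarlyStop()
        except EarlyStop:
            log.info(
                f"Early stopping patience exceeded. Stopping training after {epoch} epochs."
            )


    train(n_epochs=10, patience=3)

The observable difference from the old `break` is ordering and structure: the checkpoint managers are closed before the early-stop message is logged rather than after, and the stop condition no longer needs the log call nested inside the training loop.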