Commit: An improvement and updating an example
sadra-barikbin committed Sep 11, 2023
1 parent de6d0f9 commit 11ef922
Showing 2 changed files with 30 additions and 11 deletions.
25 changes: 17 additions & 8 deletions examples/mnist/mnist_save_resume_engine.py
@@ -13,7 +13,7 @@
 
 from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
 from ignite.handlers import Checkpoint, DiskSaver
-from ignite.metrics import Accuracy, Loss
+from ignite.metrics import Accuracy, Loss, RunningAverage
 from ignite.utils import manual_seed
 
 try:
@@ -162,24 +162,26 @@ def run(
     if deterministic:
         tqdm.write("Setup deterministic trainer")
     trainer = create_supervised_trainer(model, optimizer, criterion, device=device, deterministic=deterministic)
+    running_loss = RunningAverage(output_transform=lambda x: x)
+    running_loss.attach(trainer, "rloss")
 
-    evaluator = create_supervised_evaluator(
-        model, metrics={"accuracy": Accuracy(), "nll": Loss(criterion)}, device=device
-    )
+    metrics = {"accuracy": Accuracy(), "nll": Loss(criterion)}
+    evaluator = create_supervised_evaluator(model, metrics, device)
 
     # Apply learning rate scheduling
     @trainer.on(Events.EPOCH_COMPLETED)
     def lr_step(engine):
         lr_scheduler.step()
 
-    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=f"Epoch {0} - loss: {0:.4f} - lr: {lr:.4f}")
+    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=f"Epoch {0} - run. loss: {0:.4f} - lr: {lr:.4f}")
 
     @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
     def log_training_loss(engine):
         lr = optimizer.param_groups[0]["lr"]
-        pbar.desc = f"Epoch {engine.state.epoch} - loss: {engine.state.output:.4f} - lr: {lr:.4f}"
+        rloss = engine.state.metrics["rloss"]
+        pbar.desc = f"Epoch {engine.state.epoch} - run. loss: {rloss:.4f} - lr: {lr:.4f}"
         pbar.update(log_interval)
-        writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
+        writer.add_scalar("training/running_loss", rloss, engine.state.iteration)
         writer.add_scalar("lr", lr, engine.state.iteration)
 
     if crash_iteration > 0:
@@ -222,7 +224,14 @@ def log_validation_results(engine):
         writer.add_scalar("valdation/avg_accuracy", avg_accuracy, engine.state.epoch)
 
     # Setup object to checkpoint
-    objects_to_checkpoint = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
+    objects_to_checkpoint = {
+        "trainer": trainer,
+        "model": model,
+        "optimizer": optimizer,
+        "lr_scheduler": lr_scheduler,
+        "train_running_loss": running_loss,
+        "metrics": metrics,
+    }
     training_checkpoint = Checkpoint(
         to_save=objects_to_checkpoint,
         save_handler=DiskSaver(log_dir, require_empty=False),
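For context, the save/resume flow that consumes these checkpointed objects: a minimal resume sketch, assuming the example's `objects_to_checkpoint` mapping and `log_dir` from above, with an illustrative checkpoint filename. With this commit, the plain `metrics` dict round-trips alongside the usual trainer/model/optimizer/lr_scheduler entries.

import torch

from ignite.handlers import Checkpoint

# Illustrative filename; real names follow Checkpoint's filename pattern.
checkpoint = torch.load(log_dir + "/checkpoint_550.pt")

# Restores every entry of `objects_to_checkpoint`, now including the
# RunningAverage ("train_running_loss") and the metrics dict.
Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)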
16 changes: 13 additions & 3 deletions ignite/handlers/checkpoint.py
@@ -476,7 +476,10 @@ def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
                     obj.consolidate_state_dict(to=self.save_on_rank)
                     if self.save_on_rank != idist.get_rank():
                         continue
-                checkpoint[k] = obj.state_dict()
+                if is_dict_of_serializable_objects(obj):
+                    checkpoint[k] = {k: v.state_dict() for k, v in obj.items()}
+                else:
+                    checkpoint[k] = obj.state_dict()
         return checkpoint
 
     @staticmethod
@@ -532,8 +535,8 @@ def setup_filename_pattern(
 
     @staticmethod
     def _check_objects(objs: Mapping, attr: str) -> None:
-        for k, obj in objs.items():
-            if not hasattr(obj, attr):
+        for _, obj in objs.items():
+            if not hasattr(obj, attr) and not is_dict_of_serializable_objects(obj):
                 raise TypeError(f"Object {type(obj)} should have `{attr}` method")
 
     @staticmethod
@@ -611,6 +614,9 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
                 obj = obj.module
             if isinstance(obj, torch.nn.Module):
                 obj.load_state_dict(chkpt_obj, strict=is_state_dict_strict)
+            elif is_dict_of_serializable_objects(obj):
+                for k, v in obj.items():
+                    v.load_state_dict(chkpt_obj[k])
             else:
                 obj.load_state_dict(chkpt_obj)
 
@@ -1002,3 +1008,7 @@ def __call__(self, engine: Engine, to_save: Mapping): # type: ignore
         self._check_objects(to_save, "state_dict")
         self.to_save = to_save
         super(ModelCheckpoint, self).__call__(engine)
+
+
+def is_dict_of_serializable_objects(obj: Any) -> bool:
+    return isinstance(obj, Mapping) and all([hasattr(v, "state_dict") for v in obj.values()])
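Taken together, `Checkpoint` now accepts a `Mapping` whose values all expose `state_dict`/`load_state_dict`: `_setup_checkpoint` saves such an entry as a nested dict of state dicts, and `_load_object` restores each value by key. Below is a hypothetical round-trip sketch of that behavior; the `Counter` class is invented for illustration, while the comprehension and loop mirror the added lines above.

from typing import Any, Dict


class Counter:
    # Invented stand-in: any object with state_dict/load_state_dict qualifies,
    # so is_dict_of_serializable_objects({"a": Counter()}) is True.
    def __init__(self) -> None:
        self.n = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"n": self.n}

    def load_state_dict(self, sd: Dict[str, Any]) -> None:
        self.n = sd["n"]


objs = {"a": Counter(), "b": Counter()}
objs["a"].n = 5

# Saving: what _setup_checkpoint now emits for a dict of serializable objects.
saved = {k: v.state_dict() for k, v in objs.items()}  # {"a": {"n": 5}, "b": {"n": 0}}

# Loading: what _load_object now does for such an entry.
fresh = {"a": Counter(), "b": Counter()}
for k, v in fresh.items():
    v.load_state_dict(saved[k])
assert fresh["a"].n == 5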
