
Commit

ensure access state in runner
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 13, 2023
1 parent 725dd52 commit cb746fd
Showing 3 changed files with 54 additions and 42 deletions.
86 changes: 46 additions & 40 deletions danling/runner/base_runner.py
@@ -33,14 +33,14 @@ class BaseRunner(RunnerBase):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.init_distributed()
if self.seed is not None:
if self.state.seed is not None:
self.set_seed()
if self.deterministic:
if self.state.deterministic:
self.set_deterministic()
if self.log:
if self.state.log:
self.init_logging()
self.init_print()
if self.tensorboard:
if self.state.tensorboard:
self.init_tensorboard()
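
For readers following the diff, here is a minimal, self-contained sketch of the access pattern this commit standardizes: configuration flags live on a state object and the runner reads them through `self.state`. The `ToyState`/`ToyRunner` names are illustrative and not part of danling's API.

```python
# Hypothetical sketch (not danling's API): flags on a state object gate
# runner initialization, mirroring the checks in BaseRunner.__init__ above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyState:
    seed: Optional[int] = None
    deterministic: bool = False
    log: bool = True
    tensorboard: bool = False


class ToyRunner:
    def __init__(self, state: ToyState) -> None:
        self.state = state
        if self.state.seed is not None:
            self.set_seed()
        if self.state.deterministic:
            self.set_deterministic()
        if self.state.log:
            self.init_logging()
        if self.state.tensorboard:
            self.init_tensorboard()

    def set_seed(self) -> None:
        print(f"seeding with {self.state.seed}")

    def set_deterministic(self) -> None:
        print("enabling deterministic mode")

    def init_logging(self) -> None:
        print("configuring logging")

    def init_tensorboard(self) -> None:
        print("starting tensorboard writer")


ToyRunner(ToyState(seed=42, deterministic=True))
```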

@on_main_process
@@ -68,7 +68,7 @@ def init_logging(self) -> None:
"level": "DEBUG",
"formatter": "standard",
"class": "logging.FileHandler",
"filename": self.log_path,
"filename": self.state.log_path,
"mode": "a",
},
},
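
The hunk above points the file handler at `self.state.log_path`. A hedged, standalone sketch of the same `logging.config.dictConfig` pattern (the `standard` formatter string, root logger wiring, and `latest.log` path are assumptions, not taken from this file):

```python
# Standalone approximation of the logging setup above; the formatter string,
# root logger, and "latest.log" filename are assumptions.
import logging
import logging.config

log_path = "latest.log"  # stand-in for self.state.log_path

logging.config.dictConfig(
    {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
        },
        "handlers": {
            "file": {
                "level": "DEBUG",
                "formatter": "standard",
                "class": "logging.FileHandler",
                "filename": log_path,
                "mode": "a",
            },
        },
        "root": {"handlers": ["file"], "level": "DEBUG"},
    }
)

logging.getLogger(__name__).info("logging initialised")
```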
@@ -96,7 +96,7 @@ def init_print(self, process: int = 0) -> None:
Notes
-----
If `self.log = True`, the default `print` function will be overridden by `logging.info`.
If `self.state.log = True`, the default `print` function will be overridden by `logging.info`.
"""

logger = logging.getLogger("print")
@@ -108,7 +108,7 @@ def init_print(self, process: int = 0) -> None:
@catch
def print(*args, force=False, end="\n", file=None, flush=False, **kwargs): # pylint: disable=W0622
if self.rank == process or force:
if self.log:
if self.state.log:
logger.info(*args, **kwargs)
else:
builtin_print(*args, end=end, file=file, flush=flush, **kwargs)
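
A rough sketch of the rank-aware `print` override shown above: messages go through `logging` on the designated process and are dropped elsewhere. `rank` and `use_logging` stand in for `self.rank` and `self.state.log`, and joining `args` into one string is a simplification of the real wrapper.

```python
# Sketch of a rank-aware print that defers to logging on the main process.
import builtins
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("print")

rank = 0            # stand-in for self.rank
use_logging = True  # stand-in for self.state.log
builtin_print = builtins.print


def print(*args, process: int = 0, force: bool = False, **kwargs):  # pylint: disable=W0622
    if rank == process or force:
        if use_logging:
            logger.info(" ".join(str(arg) for arg in args))
        else:
            builtin_print(*args, **kwargs)


print("visible on the main process only")
```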
@@ -128,7 +128,7 @@ def set_seed(self, seed: Optional[int] = None, bias: Optional[int] = None) -> No
Args:
seed: Random seed to set.
Defaults to `self.seed` (`config.seed`).
Defaults to `self.state.seed` (`config.seed`).
bias: Make the seed different for each process.
@@ -140,7 +140,7 @@ def set_seed(self, seed: Optional[int] = None, bias: Optional[int] = None) -> No
"""

if seed is None:
seed = self.seed
seed = self.state.seed
if bias is None:
bias = self.rank
if bias:
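
The default `bias` is the process rank, so each process ends up with a distinct but reproducible seed. A minimal illustration of that arithmetic (the world size of 4 is assumed):

```python
# Each rank offsets the base seed by its rank (the default `bias`), so the
# streams differ across processes but stay reproducible run to run.
import random

base_seed = 1031  # stand-in for self.state.seed

for rank in range(4):  # assumed world size
    seed = base_seed + rank
    rng = random.Random(seed)
    print(f"rank {rank}: seed={seed}, first draw={rng.random():.4f}")
```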
@@ -188,7 +188,9 @@ def step(self, zero_grad: bool = True, batch_size: Optional[int] = None) -> None
r"""
Step optimizer and scheduler.
This method also increments the `self.steps` attribute.
This method increments `self.state.steps`.
This method also increments `self.state.iters` when `batch_size` is specified.
Args:
zero_grad: Whether to zero the gradients.
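
A hedged sketch of the bookkeeping the updated docstring describes: one call steps the optimizer, optionally zeroes gradients, increments a step counter, and adds `batch_size` to an iteration counter. `ToyLoop` is illustrative and omits the scheduler handling of the real method.

```python
# Illustrative only: optimizer step + counter bookkeeping as per the docstring.
from typing import Optional

import torch


class ToyLoop:
    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model
        self.optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        self.steps = 0   # stand-in for self.state.steps
        self.iters = 0   # stand-in for self.state.iters

    def step(self, zero_grad: bool = True, batch_size: Optional[int] = None) -> None:
        self.optimizer.step()
        if zero_grad:
            self.optimizer.zero_grad()
        self.steps += 1
        if batch_size is not None:
            self.iters += batch_size


model = torch.nn.Linear(4, 1)
loop = ToyLoop(model)
loss = model(torch.randn(8, 4)).sum()
loss.backward()
loop.step(batch_size=8)
print(loop.steps, loop.iters)  # 1 8
```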
@@ -217,23 +219,27 @@ def state_dict(self, cls: Callable = dict) -> Mapping:
@on_main_process
def save_checkpoint(self) -> None:
r"""
Save checkpoint to `runner.checkpoint_dir`.
Save checkpoint to `self.state.checkpoint_dir`.
The checkpoint will be saved to `runner.checkpoint_dir/latest.pth`.
The checkpoint will be saved to `self.state.checkpoint_dir/latest.pth`.
If `save_interval` is positive and `self.epochs + 1` is a multiple of `save_interval`,
the checkpoint will also be copied to `runner.checkpoint_dir/epoch-{self.epochs}.pth`.
If `self.state.save_interval` is positive and `self.state.epochs + 1` is a multiple of `save_interval`,
the checkpoint will also be copied to `self.state.checkpoint_dir/epoch-{self.state.epochs}.pth`.
If `self.is_best` is `True`, the checkpoint will also be copied to `runner.checkpoint_dir/best.pth`.
If `self.state.is_best` is `True`, the checkpoint will also be copied to `self.state.checkpoint_dir/best.pth`.
"""

latest_path = os.path.join(self.checkpoint_dir, "latest.pth")
latest_path = os.path.join(self.state.checkpoint_dir, "latest.pth")
self.save(self.state_dict(), latest_path)
if hasattr(self, "save_interval") and self.save_interval > 0 and (self.epochs + 1) % self.save_interval == 0:
save_path = os.path.join(self.checkpoint_dir, f"epoch-{self.epochs}.pth")
if (
hasattr(self, "save_interval")
and self.save_interval > 0
and (self.state.epochs + 1) % self.save_interval == 0
):
save_path = os.path.join(self.state.checkpoint_dir, f"epoch-{self.state.epochs}.pth")
shutil.copy(latest_path, save_path)
if self.is_best:
best_path = os.path.join(self.checkpoint_dir, "best.pth")
best_path = os.path.join(self.state.checkpoint_dir, "best.pth")
shutil.copy(latest_path, best_path)
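
A quick worked example of the `save_interval` condition above (the value 10 is assumed): the epoch-numbered copy is written whenever `(epochs + 1) % save_interval == 0`.

```python
# Worked example of the interval check (save_interval=10 is assumed):
save_interval = 10
extra_copies = [epoch for epoch in range(30) if (epoch + 1) % save_interval == 0]
print(extra_copies)  # [9, 19, 29] -> epoch-9.pth, epoch-19.pth, epoch-29.pth
```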

def load_checkpoint( # pylint: disable=W1113
@@ -244,10 +250,10 @@ def load_checkpoint( # pylint: disable=W1113
Args:
checkpoint: Checkpoint (or its path) to load.
Defaults to `runner.checkpoint_dir/latest.pth`.
Defaults to `self.state.checkpoint_dir/latest.pth`.
override_config: If True, override runner config with checkpoint config.
*args: Additional arguments to pass to `runner.load`.
**kwargs: Additional keyword arguments to pass to `runner.load`.
*args: Additional arguments to pass to `self.load`.
**kwargs: Additional keyword arguments to pass to `self.load`.
Raises:
FileNotFoundError: If `checkpoint` does not exist.
@@ -258,7 +264,7 @@
"""

if checkpoint is None:
checkpoint = os.path.join(self.checkpoint_dir, "latest.pth") # type: ignore
checkpoint = os.path.join(self.state.checkpoint_dir, "latest.pth") # type: ignore
# TODO: Support loading checkpoints in other formats
if isinstance(checkpoint, str):
if not os.path.exists(checkpoint):
@@ -281,8 +287,8 @@ def load_pretrained(self, checkpoint: Union[Mapping, str], *args, **kwargs) -> N
Args:
checkpoint: Pretrained checkpoint (or its path) to load.
*args: Additional arguments to pass to `runner.load`.
**kwargs: Additional keyword arguments to pass to `runner.load`.
*args: Additional arguments to pass to `self.load`.
**kwargs: Additional keyword arguments to pass to `self.load`.
Raises:
FileNotFoundError: If `checkpoint` does not exist.
@@ -309,9 +315,9 @@ def from_checkpoint(cls, checkpoint: Union[Mapping, str], *args, **kwargs) -> Ba
Args:
checkpoint: Checkpoint (or its path) to load.
Defaults to `runner.checkpoint_dir/latest.pth`.
*args: Additional arguments to pass to `runner.load`.
**kwargs: Additional keyword arguments to pass to `runner.load`.
Defaults to `self.state.checkpoint_dir/latest.pth`.
*args: Additional arguments to pass to `self.load`.
**kwargs: Additional keyword arguments to pass to `self.load`.
Returns:
(BaseRunner):
Expand All @@ -325,39 +331,39 @@ def from_checkpoint(cls, checkpoint: Union[Mapping, str], *args, **kwargs) -> Ba

def append_result(self, result) -> None:
r"""
Append result to `self.results`.
Append result to `self.state.results`.
Warnings:
`self.results` is heavily relied upon for computing metrics.
`self.state.results` is heavily relied upon for computing metrics.
Failing to use this method may lead to unexpected behavior.
"""

self.results.append(result)
self.state.results.append(result)

def print_result(self) -> None:
r"""
Print latest and best result.
"""

print(f"results: {self.results}")
print(f"latest result: {self.latest_result}")
print(f"best result: {self.best_result}")
print(f"results: {self.state.results}")
print(f"latest result: {self.state.latest_result}")
print(f"best result: {self.state.best_result}")

@catch
@on_main_process
def save_result(self) -> None:
r"""
Save result to `runner.dir`.
Save result to `self.state.dir`.
This method will save latest and best result to
`runner.dir/latest.json` and `runner.dir/best.json` respectively.
`self.state.dir/latest.json` and `self.state.dir/best.json` respectively.
"""

results_path = os.path.join(self.dir, "results.json")
self.save({"id": self.id, "name": self.name, "results": self.results}, results_path, indent=4)
ret = {"id": self.id, "name": self.name}
result = self.latest_result # type: ignore
results_path = os.path.join(self.state.dir, "results.json")
self.save({"id": self.state.id, "name": self.state.name, "results": self.state.results}, results_path, indent=4)
ret = {"id": self.state.id, "name": self.state.name}
result = self.state.latest_result # type: ignore
if isinstance(result, FlatDict):
result = result.dict() # type: ignore
# This is slower but ensures id is the first key
6 changes: 6 additions & 0 deletions danling/runner/runner_state.py
@@ -132,6 +132,10 @@ class RunnerState(NestedDict):
Defaults to `True`.
tensorboard (bool): Whether to use `tensorboard`.
Defaults to `False`.
print_interval (int): Interval of printing logs.
Defaults to -1.
save_interval (int): Interval of saving intermediate checkpoints.
Defaults to -1 (never save intermediate checkpoints).
Notes:
`RunnerState` is a `NestedDict`, so you can access its attributes by `state["name"]` or `state.name`.
@@ -172,6 +176,8 @@ class RunnerState(NestedDict):
checkpoint_dir_name: str = "checkpoints"
log: bool = True
tensorboard: bool = False
print_interval: int = -1
save_interval: int = -1

def __init__(self, *args, **kwargs):
if Repo is not None:
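As the `RunnerState` docstring notes, the state is a `NestedDict`, so entries can be read as attributes or as items. The tiny stand-in below only illustrates that dual access; it is not chanfig's implementation.

```python
# Tiny stand-in for attribute/item dual access; not chanfig's NestedDict.
class AttrDict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


state = AttrDict(print_interval=-1, save_interval=-1)
state.save_interval = 5
print(state["save_interval"], state.print_interval)  # 5 -1
```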
4 changes: 2 additions & 2 deletions danling/runner/torch_runner.py
@@ -72,7 +72,7 @@ def set_seed(self, seed: int = None, bias: Optional[int] = None) -> None: # typ
Args:
seed: Random seed to set.
Defaults to `self.seed` (`config.seed`).
Defaults to `self.state.seed` (`config.seed`).
bias: Make the seed different for each process.
@@ -84,7 +84,7 @@ def set_seed(self, seed: int = None, bias: Optional[int] = None) -> None: # typ
"""

if seed is None:
seed = self.seed
seed = self.state.seed
if self.distributed:
object_list = [seed]
dist.broadcast_object_list(object_list)
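A hedged sketch of the broadcast above: rank 0's seed is shared with every process via `broadcast_object_list` so all ranks agree on the base seed before applying their per-rank bias. The single-process gloo group and the address/port are assumptions so the snippet runs standalone.

```python
# Single-process illustration of the seed broadcast; address/port are assumed.
import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

seed = 1031 if dist.get_rank() == 0 else None
object_list = [seed]
dist.broadcast_object_list(object_list, src=0)
seed = object_list[0]
print(f"rank {dist.get_rank()} uses seed {seed}")

dist.destroy_process_group()
```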
