From cb746fda7b7650eafa5d1a2ef912d073ecc54a95 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 13 Apr 2023 17:56:54 +0800 Subject: [PATCH] ensure access state in runner Signed-off-by: Zhiyuan Chen --- danling/runner/base_runner.py | 86 ++++++++++++++++++---------------- danling/runner/runner_state.py | 6 +++ danling/runner/torch_runner.py | 4 +- 3 files changed, 54 insertions(+), 42 deletions(-) diff --git a/danling/runner/base_runner.py b/danling/runner/base_runner.py index b29a1212..72198645 100644 --- a/danling/runner/base_runner.py +++ b/danling/runner/base_runner.py @@ -33,14 +33,14 @@ class BaseRunner(RunnerBase): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.init_distributed() - if self.seed is not None: + if self.state.seed is not None: self.set_seed() - if self.deterministic: + if self.state.deterministic: self.set_deterministic() - if self.log: + if self.state.log: self.init_logging() self.init_print() - if self.tensorboard: + if self.state.tensorboard: self.init_tensorboard() @on_main_process @@ -68,7 +68,7 @@ def init_logging(self) -> None: "level": "DEBUG", "formatter": "standard", "class": "logging.FileHandler", - "filename": self.log_path, + "filename": self.state.log_path, "mode": "a", }, }, @@ -96,7 +96,7 @@ def init_print(self, process: int = 0) -> None: Notes ----- - If `self.log = True`, the default `print` function will be override by `logging.info`. + If `self.state.log = True`, the default `print` function will be override by `logging.info`. """ logger = logging.getLogger("print") @@ -108,7 +108,7 @@ def init_print(self, process: int = 0) -> None: @catch def print(*args, force=False, end="\n", file=None, flush=False, **kwargs): # pylint: disable=W0622 if self.rank == process or force: - if self.log: + if self.state.log: logger.info(*args, **kwargs) else: builtin_print(*args, end=end, file=file, flush=flush, **kwargs) @@ -128,7 +128,7 @@ def set_seed(self, seed: Optional[int] = None, bias: Optional[int] = None) -> No Args: seed: Random seed to set. - Defaults to `self.seed` (`config.seed`). + Defaults to `self.state.seed` (`config.seed`). bias: Make the seed different for each processes. @@ -140,7 +140,7 @@ def set_seed(self, seed: Optional[int] = None, bias: Optional[int] = None) -> No """ if seed is None: - seed = self.seed + seed = self.state.seed if bias is None: bias = self.rank if bias: @@ -188,7 +188,9 @@ def step(self, zero_grad: bool = True, batch_size: Optional[int] = None) -> None r""" Step optimizer and scheduler. - This method also increment the `self.steps` attribute. + This method increment `self.state.steps`. + + This method also increment `self.state.iters` when `batch_size` is specified. Args: zero_grad: Whether to zero the gradients. @@ -217,23 +219,27 @@ def state_dict(self, cls: Callable = dict) -> Mapping: @on_main_process def save_checkpoint(self) -> None: r""" - Save checkpoint to `runner.checkpoint_dir`. + Save checkpoint to `self.state.checkpoint_dir`. - The checkpoint will be saved to `runner.checkpoint_dir/latest.pth`. + The checkpoint will be saved to `self.state.checkpoint_dir/latest.pth`. - If `save_interval` is positive and `self.epochs + 1` is a multiple of `save_interval`, - the checkpoint will also be copied to `runner.checkpoint_dir/epoch-{self.epochs}.pth`. + If `self.state.save_interval` is positive and `self.state.epochs + 1` is a multiple of `save_interval`, + the checkpoint will also be copied to `self.state.checkpoint_dir/epoch-{self.state.epochs}.pth`. - If `self.is_best` is `True`, the checkpoint will also be copied to `runner.checkpoint_dir/best.pth`. + If `self.state.is_best` is `True`, the checkpoint will also be copied to `self.state.checkpoint_dir/best.pth`. """ - latest_path = os.path.join(self.checkpoint_dir, "latest.pth") + latest_path = os.path.join(self.state.checkpoint_dir, "latest.pth") self.save(self.state_dict(), latest_path) - if hasattr(self, "save_interval") and self.save_interval > 0 and (self.epochs + 1) % self.save_interval == 0: - save_path = os.path.join(self.checkpoint_dir, f"epoch-{self.epochs}.pth") + if ( + hasattr(self, "save_interval") + and self.save_interval > 0 + and (self.state.epochs + 1) % self.save_interval == 0 + ): + save_path = os.path.join(self.state.checkpoint_dir, f"epoch-{self.state.epochs}.pth") shutil.copy(latest_path, save_path) if self.is_best: - best_path = os.path.join(self.checkpoint_dir, "best.pth") + best_path = os.path.join(self.state.checkpoint_dir, "best.pth") shutil.copy(latest_path, best_path) def load_checkpoint( # pylint: disable=W1113 @@ -244,10 +250,10 @@ def load_checkpoint( # pylint: disable=W1113 Args: checkpoint: Checkpoint (or its path) to load. - Defaults to `runner.checkpoint_dir/latest.pth`. + Defaults to `self.state.checkpoint_dir/latest.pth`. override_config: If True, override runner config with checkpoint config. - *args: Additional arguments to pass to `runner.load`. - **kwargs: Additional keyword arguments to pass to `runner.load`. + *args: Additional arguments to pass to `self.load`. + **kwargs: Additional keyword arguments to pass to `self.load`. Raises: FileNotFoundError: If `checkpoint` does not exists. @@ -258,7 +264,7 @@ def load_checkpoint( # pylint: disable=W1113 """ if checkpoint is None: - checkpoint = os.path.join(self.checkpoint_dir, "latest.pth") # type: ignore + checkpoint = os.path.join(self.state.checkpoint_dir, "latest.pth") # type: ignore # TODO: Support loading checkpoints in other format if isinstance(checkpoint, str): if not os.path.exists(checkpoint): @@ -281,8 +287,8 @@ def load_pretrained(self, checkpoint: Union[Mapping, str], *args, **kwargs) -> N Args: checkpoint: Pretrained checkpoint (or its path) to load. - *args: Additional arguments to pass to `runner.load`. - **kwargs: Additional keyword arguments to pass to `runner.load`. + *args: Additional arguments to pass to `self.load`. + **kwargs: Additional keyword arguments to pass to `self.load`. Raises: FileNotFoundError: If `checkpoint` does not exists. @@ -309,9 +315,9 @@ def from_checkpoint(cls, checkpoint: Union[Mapping, str], *args, **kwargs) -> Ba Args: checkpoint: Checkpoint (or its path) to load. - Defaults to `runner.checkpoint_dir/latest.pth`. - *args: Additional arguments to pass to `runner.load`. - **kwargs: Additional keyword arguments to pass to `runner.load`. + Defaults to `self.state.checkpoint_dir/latest.pth`. + *args: Additional arguments to pass to `self.load`. + **kwargs: Additional keyword arguments to pass to `self.load`. Returns: (BaseRunner): @@ -325,39 +331,39 @@ def from_checkpoint(cls, checkpoint: Union[Mapping, str], *args, **kwargs) -> Ba def append_result(self, result) -> None: r""" - Append result to `self.results`. + Append result to `self.state.results`. Warnings: - `self.results` is heavily relied upon for computing metrics. + `self.state.results` is heavily relied upon for computing metrics. Failed to use this method may lead to unexpected behavior. """ - self.results.append(result) + self.state.results.append(result) def print_result(self) -> None: r""" Print latest and best result. """ - print(f"results: {self.results}") - print(f"latest result: {self.latest_result}") - print(f"best result: {self.best_result}") + print(f"results: {self.state.results}") + print(f"latest result: {self.state.latest_result}") + print(f"best result: {self.state.best_result}") @catch @on_main_process def save_result(self) -> None: r""" - Save result to `runner.dir`. + Save result to `self.state.dir`. This method will save latest and best result to - `runner.dir/latest.json` and `runner.dir/best.json` respectively. + `self.state.dir/latest.json` and `self.state.dir/best.json` respectively. """ - results_path = os.path.join(self.dir, "results.json") - self.save({"id": self.id, "name": self.name, "results": self.results}, results_path, indent=4) - ret = {"id": self.id, "name": self.name} - result = self.latest_result # type: ignore + results_path = os.path.join(self.state.dir, "results.json") + self.save({"id": self.state.id, "name": self.state.name, "results": self.state.results}, results_path, indent=4) + ret = {"id": self.state.id, "name": self.state.name} + result = self.state.latest_result # type: ignore if isinstance(result, FlatDict): result = result.dict() # type: ignore # This is slower but ensure id is the first key diff --git a/danling/runner/runner_state.py b/danling/runner/runner_state.py index 2e135d26..b4020539 100644 --- a/danling/runner/runner_state.py +++ b/danling/runner/runner_state.py @@ -132,6 +132,10 @@ class RunnerState(NestedDict): Defaults to `True`. tensorboard (bool): Whether to use `tensorboard`. Defaults to `False`. + print_interval (int): Interval of printing logs. + Defaults to -1. + save_interval (int): Interval of saving intermediate checkpoints. + Defaults to -1, never save intermediate checkpoints. Notes: `RunnerState` is a `NestedDict`, so you can access its attributes by `state["name"]` or `state.name`. @@ -172,6 +176,8 @@ class RunnerState(NestedDict): checkpoint_dir_name: str = "checkpoints" log: bool = True tensorboard: bool = False + print_interval: int = -1 + save_interval: int = -1 def __init__(self, *args, **kwargs): if Repo is not None: diff --git a/danling/runner/torch_runner.py b/danling/runner/torch_runner.py index 77ba1a76..d2fa72e0 100644 --- a/danling/runner/torch_runner.py +++ b/danling/runner/torch_runner.py @@ -72,7 +72,7 @@ def set_seed(self, seed: int = None, bias: Optional[int] = None) -> None: # typ Args: seed: Random seed to set. - Defaults to `self.seed` (`config.seed`). + Defaults to `self.state.seed` (`config.seed`). bias: Make the seed different for each processes. @@ -84,7 +84,7 @@ def set_seed(self, seed: int = None, bias: Optional[int] = None) -> None: # typ """ if seed is None: - seed = self.seed + seed = self.state.seed if self.distributed: object_list = [seed] dist.broadcast_object_list(object_list)