diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
index 359a39d0..746494ef 100644
--- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py
+++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py
@@ -694,10 +694,14 @@ def run_symbolic_function(self, g: torch._C.Graph, n: torch._C.Node, sym_func: C
         if "module" in attrs:
             del attrs["module"]
 
         if pytorch_pfn_extras.requires("1.13"):
+            if pytorch_pfn_extras.requires("2.4.0.dev"):
+                g_ctx_kwargs: Dict[str, Any] = {"values_in_env": set()}
+            else:
+                g_ctx_kwargs = {}
             g_ctx = GraphContext(
                 graph=g, block=n.owningBlock(), opset=self.opset_version, original_node=n,
-                params_dict=self.vars, env=self.torch2onnx_var)
+                params_dict=self.vars, env=self.torch2onnx_var, **g_ctx_kwargs)
         else:
             g_ctx = g  # type: ignore
         if (
diff --git a/pytorch_pfn_extras/training/extensions/lr_scheduler.py b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
index 64e2ce30..ead7420b 100644
--- a/pytorch_pfn_extras/training/extensions/lr_scheduler.py
+++ b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
@@ -1,10 +1,12 @@
 from typing import Any, Dict, Optional
 
+from pytorch_pfn_extras._torch_version import requires
 from pytorch_pfn_extras.training import extension
 from pytorch_pfn_extras.training import trigger as trigger_module
 from pytorch_pfn_extras.training._manager_protocol import (
     ExtensionsManagerProtocol,
 )
+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 
@@ -33,6 +35,22 @@ def _default_stepper(
     scheduler.step()
 
 
+def check_optimizer_is_called(optimizer: Optimizer) -> bool:
+    if requires("2.4.0.dev"):
+        # https://github.com/pytorch/pytorch/blob/afda6685ae87cce7ac2fe4bac3926572da2960f7/torch/optim/lr_scheduler.py#L172-L191
+        # TODO: Rewrite this URL when pytorch 2.4.0 is released.
+        if hasattr(optimizer.step, "_wrapped_by_lr_sched"):
+            return getattr(optimizer, "_opt_called", False)
+        else:
+            return True
+    else:
+        # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/optim/lr_scheduler.py#L137-L138
+        if hasattr(optimizer.step, "_with_counter"):
+            return bool(optimizer._step_count >= 1)  # type: ignore[attr-defined]
+        else:
+            return True
+
+
 class LRScheduler(extension.Extension):
     """Trainer extension to adjust the learning rate using PyTorch's
     learning rate scheduler.
@@ -72,8 +90,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
         # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/optim/lr_scheduler.py#L137-L138
         if (
             self.wait_for_first_optimizer_step
-            and hasattr(self.scheduler.optimizer.step, "_with_counter")
-            and self.scheduler.optimizer._step_count < 1
+            and not check_optimizer_is_called(self.scheduler.optimizer)
        ):
             return
         self.stepper(manager, self.scheduler)
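
Illustration (not part of the patch): a minimal sketch of how the new check_optimizer_is_called helper introduced above behaves with an ordinary PyTorch optimizer/scheduler pair. The model, optimizer, and scheduler names below are arbitrary examples; only the helper and its import path come from the patched lr_scheduler.py.

    import torch

    from pytorch_pfn_extras.training.extensions.lr_scheduler import (
        check_optimizer_is_called,
    )

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Constructing the scheduler wraps optimizer.step (marking it with
    # "_wrapped_by_lr_sched" on PyTorch >= 2.4, "_with_counter" on older
    # releases), which is what the helper inspects.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    assert not check_optimizer_is_called(optimizer)  # optimizer.step() not called yet

    model(torch.zeros(1, 2)).sum().backward()
    optimizer.step()

    assert check_optimizer_is_called(optimizer)  # extension may now call scheduler.step()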