diff --git a/pprobe/bootstrap/hook_setup.py b/pprobe/bootstrap/hook_setup.py
index 2029964..ae3ddf5 100644
--- a/pprobe/bootstrap/hook_setup.py
+++ b/pprobe/bootstrap/hook_setup.py
@@ -241,12 +241,6 @@ def run_torch_catch_step_hook(self):
     def run_torch_catch_lr_hook(self):
         Logger.info(f"[PPROBE] torch catch lr hook executed")
 
-        self.module.optim.lr_scheduler.LRScheduler.step = (
-            lr_scheduler_step_method_wrapper(
-                self.module.optim.lr_scheduler.LRScheduler.step
-            )
-        )
-
         self.module.optim.Optimizer.zero_grad = optimizer_zero_grad_method_wrapper(
             self.module.optim.Optimizer.zero_grad
         )
diff --git a/pprobe/bootstrap/hooks/pytorch_optim.py b/pprobe/bootstrap/hooks/pytorch_optim.py
index 0beb3a2..65b4ee5 100644
--- a/pprobe/bootstrap/hooks/pytorch_optim.py
+++ b/pprobe/bootstrap/hooks/pytorch_optim.py
@@ -2,22 +2,15 @@
 from typing import Any, Callable, Optional
 
 
-def lr_scheduler_step_method_wrapper(
-    original_method: Callable, epoch: Optional[int] = None
-):
-    def wrapper(self) -> Any:
-        lr = self.get_lr()
-        Logger.info(f"[PPROBE] INIT LR ===> {lr}")
-        return original_method(self, epoch)
-
-    return wrapper
-
-
 def optimizer_zero_grad_method_wrapper(
     original_method, set_to_none: bool = True
 ) -> None:
     def wrapper(self) -> Any:
         for group in self.param_groups:
+            """
+            group:
+            dict_keys(['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov', 'maximize', 'foreach', 'differentiable', 'fused', 'initial_lr'])
+            """
             lr = group["lr"]
             Logger.info(f"[PPROBE] Iteration optimizer lr ===> {lr}")
             return original_method(self, set_to_none)
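
Note (not part of the diff): the sketch below illustrates the monkey-patching pattern that the retained `optimizer_zero_grad_method_wrapper` relies on, applied standalone to `torch.optim.Optimizer.zero_grad`. It is a minimal sketch under stated assumptions: the standard `logging` module stands in for pprobe's `Logger`, the patch is applied directly rather than through `hook_setup.py`, and every param group is logged before delegating (the in-tree wrapper returns from inside the loop after the first group).

```python
import logging
from typing import Any, Callable

import torch

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("pprobe-sketch")  # stand-in for pprobe's Logger


def optimizer_zero_grad_method_wrapper(
    original_method: Callable, set_to_none: bool = True
) -> Callable:
    def wrapper(self) -> Any:
        # Each param group is a dict carrying 'params', 'lr', 'momentum', etc.
        # Log the current lr of every group, then delegate to the real zero_grad.
        for group in self.param_groups:
            log.info(f"[PPROBE] Iteration optimizer lr ===> {group['lr']}")
        return original_method(self, set_to_none)

    return wrapper


# Patch once at startup; every subsequent zero_grad() call logs the per-group lr.
torch.optim.Optimizer.zero_grad = optimizer_zero_grad_method_wrapper(
    torch.optim.Optimizer.zero_grad
)

model = torch.nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
optim.zero_grad()  # emits: [PPROBE] Iteration optimizer lr ===> 0.1
```

Wrapping `zero_grad` observes the lr once per training iteration without touching the LR scheduler, which is consistent with dropping the `LRScheduler.step` wrapper in this diff.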