Skip to content

Commit

Permalink
update lr hook
Browse files Browse the repository at this point in the history
  • Loading branch information
clemente0731 committed Aug 8, 2024
1 parent 3329ac7 commit 2c939d6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 17 deletions.
6 changes: 0 additions & 6 deletions pprobe/bootstrap/hook_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,6 @@ def run_torch_catch_step_hook(self):

def run_torch_catch_lr_hook(self):
Logger.info(f"[PPROBE] torch catch lr hook executed")
self.module.optim.lr_scheduler.LRScheduler.step = (
lr_scheduler_step_method_wrapper(
self.module.optim.lr_scheduler.LRScheduler.step
)
)

self.module.optim.Optimizer.zero_grad = optimizer_zero_grad_method_wrapper(
self.module.optim.Optimizer.zero_grad
)
Expand Down
15 changes: 4 additions & 11 deletions pprobe/bootstrap/hooks/pytorch_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,15 @@
from typing import Any, Callable, Optional


def lr_scheduler_step_method_wrapper(
    original_method: Callable, epoch: Optional[int] = None
):
    """Wrap ``LRScheduler.step`` so each call logs the learning rate first.

    Intended for monkey-patching
    ``torch.optim.lr_scheduler.LRScheduler.step`` with the returned wrapper.

    Args:
        original_method: The unbound ``step`` method being wrapped.
        epoch: Default epoch forwarded to ``original_method`` when the
            caller does not pass one.

    Returns:
        A ``step``-compatible callable that logs ``self.get_lr()`` and then
        delegates to ``original_method``.
    """

    # FIX: the previous wrapper was ``def wrapper(self)``, so once patched
    # onto the scheduler class, ``scheduler.step(epoch)`` raised TypeError.
    # Accept an optional per-call epoch, falling back to the factory-level
    # default to keep the old behavior for ``scheduler.step()``.
    def wrapper(self, epoch: Optional[int] = epoch) -> Any:
        # get_lr() is the scheduler's own per-param-group LR computation.
        lr = self.get_lr()
        Logger.info(f"[PPROBE] INIT LR ===> {lr}")
        return original_method(self, epoch)

    return wrapper


def optimizer_zero_grad_method_wrapper(
original_method, set_to_none: bool = True
) -> None:
def wrapper(self) -> Any:
for group in self.param_groups:
"""
group:
dict_keys(['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov', 'maximize', 'foreach', 'differentiable', 'fused', 'initial_lr'])
"""
lr = group["lr"]
Logger.info(f"[PPROBE] Iteration optimizer lr ===> {lr}")
return original_method(self, set_to_none)
Expand Down

0 comments on commit 2c939d6

Please sign in to comment.