add lr hook
clemente0731 committed Aug 8, 2024
1 parent e99b723 commit 3329ac7
Showing 4 changed files with 47 additions and 14 deletions.
29 changes: 17 additions & 12 deletions pprobe/bootstrap/hook_setup.py
@@ -7,6 +7,10 @@
     func_torch_step_count_wrapper,
     dataloader_next_method_wrapper,
 )
+from pprobe.bootstrap.hooks.pytorch_optim import (
+    lr_scheduler_step_method_wrapper,
+    optimizer_zero_grad_method_wrapper,
+)
 from pprobe.bootstrap.hooks.pytorch_perf import func_torch_device_conversion_wrapper
 
 
@@ -84,45 +88,34 @@ def check_and_run_hook(self, module_fullname):
         # torch part
         if module_fullname == "torch":
             if self.torch_reproduce_enabled:
-                # TODO
                 pass
             if self.torch_catch_step_enabled:
                 self.run_torch_catch_step_hook()
             if self.torch_catch_loss_enabled:
-                # TODO
                 pass
             if self.torch_catch_lr_enabled:
-                # TODO
-                pass
+                self.run_torch_catch_lr_hook()
             if self.torch_dump_op_enabled:
                 self.run_torch_func_hook()
             if self.torch_dump_aten_enabled:
-                # TODO
                 pass
             if self.torch_dump_dist_enabled:
                 self.run_torch_dist_hook()
             if self.torch_dump_module_enabled:
                 self.run_torch_module_hook()
             if self.torch_dump_optim_enabled:
-                # TODO
                 pass
             if self.torch_catch_memory_enabled:
-                # TODO
                 pass
             if self.torch_test_dump_op_enabled:
-                # TODO
                 pass
             if self.torch_test_dump_dist_enabled:
-                # TODO
                 pass
             if self.torch_test_dump_module_enabled:
-                # TODO
                 pass
             if self.torch_perf_issue_enabled:
-                # TODO
                 pass
             if self.torch_torch_trace_file_enabled:
-                # TODO
                 pass
         else:
             self.print_warning()
@@ -246,6 +239,18 @@ def run_torch_catch_step_hook(self):
             )
         )
 
+    def run_torch_catch_lr_hook(self):
+        Logger.info("[PPROBE] torch catch lr hook executed")
+        self.module.optim.lr_scheduler.LRScheduler.step = (
+            lr_scheduler_step_method_wrapper(
+                self.module.optim.lr_scheduler.LRScheduler.step
+            )
+        )
+
+        self.module.optim.Optimizer.zero_grad = optimizer_zero_grad_method_wrapper(
+            self.module.optim.Optimizer.zero_grad
+        )
+
     def run_torch_perf_hook(self):
         Logger.info(f"[PPROBE] torch perf hook executed")
 
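Note on the mechanism: run_torch_catch_lr_hook reassigns methods on the class objects of the already-imported torch module, so every scheduler and optimizer instance, existing or future, picks up the logging wrapper. A minimal standalone sketch of the same pattern, assuming PyTorch >= 2.0 where torch.optim.lr_scheduler.LRScheduler is the public base class (older releases expose it as _LRScheduler):

import torch

def logged_step(original_step):
    def wrapper(self, epoch=None):
        # get_last_lr() returns the rate(s) computed by the most recent step.
        print(f"lr before step: {self.get_last_lr()}")
        return original_step(self, epoch)
    return wrapper

torch.optim.lr_scheduler.LRScheduler.step = logged_step(
    torch.optim.lr_scheduler.LRScheduler.step
)

Because the assignment targets the base class, subclasses such as StepLR or CosineAnnealingLR inherit the wrapped step, which is why the hook only needs to run once when torch is first imported.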
25 changes: 25 additions & 0 deletions pprobe/bootstrap/hooks/pytorch_optim.py
@@ -0,0 +1,25 @@
+from pprobe.utils.logging import Logger
+from typing import Any, Callable, Optional
+
+
+def lr_scheduler_step_method_wrapper(original_method: Callable) -> Callable:
+    def wrapper(self, epoch: Optional[int] = None) -> Any:
+        # Log the learning rate(s) the scheduler computes for this step.
+        lr = self.get_lr()
+        Logger.info(f"[PPROBE] scheduler step lr ===> {lr}")
+        return original_method(self, epoch)
+
+    return wrapper
+
+
+def optimizer_zero_grad_method_wrapper(original_method: Callable) -> Callable:
+    def wrapper(self, set_to_none: bool = True) -> Any:
+        # Log the current learning rate of every param group once per iteration.
+        for group in self.param_groups:
+            lr = group["lr"]
+            Logger.info(f"[PPROBE] iteration optimizer lr ===> {lr}")
+        return original_method(self, set_to_none)
+
+    return wrapper
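A hypothetical smoke test for these wrappers (the toy model, optimizer, and scheduler below are illustrative, not part of this commit; assumes PyTorch >= 2.0 and an importable pprobe package):

import torch
from pprobe.bootstrap.hooks.pytorch_optim import (
    lr_scheduler_step_method_wrapper,
    optimizer_zero_grad_method_wrapper,
)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

# Patch the classes the same way run_torch_catch_lr_hook does.
torch.optim.lr_scheduler.LRScheduler.step = lr_scheduler_step_method_wrapper(
    torch.optim.lr_scheduler.LRScheduler.step
)
torch.optim.Optimizer.zero_grad = optimizer_zero_grad_method_wrapper(
    torch.optim.Optimizer.zero_grad
)

for _ in range(3):
    optimizer.zero_grad()  # logs the lr of each param group
    model(torch.randn(1, 4)).sum().backward()
    optimizer.step()
    scheduler.step()  # logs the lr via the wrapped step

Each iteration should emit one optimizer line per param group plus one scheduler line, with the optimizer's rate halving on each iteration because of the gamma=0.5 schedule.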
4 changes: 3 additions & 1 deletion pprobe/tests/xtest_torchvision_model.py
@@ -506,7 +506,9 @@ def run_validate(loader, base_progress=0):
                 progress.display(i + 1)
 
             if i >= 2:
-                print(f"MODEL EVAL FINISH {args.arch}: time duration:{time.time()-ST}")
+                print(
+                    f"MODEL EVAL FINISH {args.arch}: time duration:{time.time()-ST}"
+                )
                 import sys
 
                 sys.exit(0)
3 changes: 2 additions & 1 deletion script/test.sh
@@ -11,7 +11,8 @@ cd pprobe/tests
 
 PPROBE --enable PPROBE_ENABLE
 PPROBE --enable TORCH_DUMP_MODULE
-PPROBE --enable TORCH_CATCH_STEP
+PPROBE --enable TORCH_CATCH_STEP
+PPROBE --enable TORCH_CATCH_LR
 
 PPROBE_ENABLE=1 python xtest_torchvision_model.py -a resnet50 --epochs 1 -b 12 -p 1 --seed 42 --dummy
 PPROBE_ENABLE=1 python xtest_torchvision_model.py -a resnet50 --epochs 1 -b 12 -p 1 --seed 42 --dummy --evaluate
