diff --git a/pprobe/bootstrap/hook_setup.py b/pprobe/bootstrap/hook_setup.py
index 9d01422..2029964 100644
--- a/pprobe/bootstrap/hook_setup.py
+++ b/pprobe/bootstrap/hook_setup.py
@@ -7,6 +7,10 @@
     func_torch_step_count_wrapper,
     dataloader_next_method_wrapper,
 )
+from pprobe.bootstrap.hooks.pytorch_optim import (
+    lr_scheduler_step_method_wrapper,
+    optimizer_zero_grad_method_wrapper,
+)
 from pprobe.bootstrap.hooks.pytorch_perf import func_torch_device_conversion_wrapper
 
 
@@ -84,45 +88,34 @@ def check_and_run_hook(self, module_fullname):
         # torch part
         if module_fullname == "torch":
             if self.torch_reproduce_enabled:
-                # TODO
                 pass
             if self.torch_catch_step_enabled:
                 self.run_torch_catch_step_hook()
             if self.torch_catch_loss_enabled:
-                # TODO
                 pass
             if self.torch_catch_lr_enabled:
-                # TODO
-                pass
+                self.run_torch_catch_lr_hook()
             if self.torch_dump_op_enabled:
                 self.run_torch_func_hook()
             if self.torch_dump_aten_enabled:
-                # TODO
                 pass
             if self.torch_dump_dist_enabled:
                 self.run_torch_dist_hook()
             if self.torch_dump_module_enabled:
                 self.run_torch_module_hook()
             if self.torch_dump_optim_enabled:
-                # TODO
                 pass
             if self.torch_catch_memory_enabled:
-                # TODO
                 pass
             if self.torch_test_dump_op_enabled:
-                # TODO
                 pass
             if self.torch_test_dump_dist_enabled:
-                # TODO
                 pass
             if self.torch_test_dump_module_enabled:
-                # TODO
                 pass
             if self.torch_perf_issue_enabled:
-                # TODO
                 pass
             if self.torch_torch_trace_file_enabled:
-                # TODO
                 pass
         else:
             self.print_warning()
@@ -246,6 +239,18 @@ def run_torch_catch_step_hook(self):
             )
         )
 
+    def run_torch_catch_lr_hook(self):
+        Logger.info(f"[PPROBE] torch catch lr hook executed")
+        self.module.optim.lr_scheduler.LRScheduler.step = (
+            lr_scheduler_step_method_wrapper(
+                self.module.optim.lr_scheduler.LRScheduler.step
+            )
+        )
+
+        self.module.optim.Optimizer.zero_grad = optimizer_zero_grad_method_wrapper(
+            self.module.optim.Optimizer.zero_grad
+        )
+
     def run_torch_perf_hook(self):
         Logger.info(f"[PPROBE] torch perf hook executed")
 
diff --git a/pprobe/bootstrap/hooks/pytorch_optim.py b/pprobe/bootstrap/hooks/pytorch_optim.py
index e69de29..0beb3a2 100644
--- a/pprobe/bootstrap/hooks/pytorch_optim.py
+++ b/pprobe/bootstrap/hooks/pytorch_optim.py
@@ -0,0 +1,25 @@
+from pprobe.utils.logging import Logger
+from typing import Any, Callable, Optional
+
+
+def lr_scheduler_step_method_wrapper(original_method: Callable) -> Callable:
+    # Log the learning rate each time LRScheduler.step is called,
+    # then delegate to the original method with the caller's epoch.
+    def wrapper(self, epoch: Optional[int] = None) -> Any:
+        lr = self.get_lr()
+        Logger.info(f"[PPROBE] INIT LR ===> {lr}")
+        return original_method(self, epoch)
+
+    return wrapper
+
+
+def optimizer_zero_grad_method_wrapper(original_method: Callable) -> Callable:
+    # Log each param group's learning rate on every Optimizer.zero_grad
+    # call, then delegate to the original method.
+    def wrapper(self, set_to_none: bool = True) -> Any:
+        for group in self.param_groups:
+            lr = group["lr"]
+            Logger.info(f"[PPROBE] Iteration optimizer lr ===> {lr}")
+        return original_method(self, set_to_none)
+
+    return wrapper
diff --git a/pprobe/tests/xtest_torchvision_model.py b/pprobe/tests/xtest_torchvision_model.py
index a11256d..cda1cd5 100644
--- a/pprobe/tests/xtest_torchvision_model.py
+++ b/pprobe/tests/xtest_torchvision_model.py
@@ -506,7 +506,9 @@ def run_validate(loader, base_progress=0):
                 progress.display(i + 1)
 
             if i >= 2:
-                print(f"MODEL EVAL FINISH {args.arch}: time duration:{time.time()-ST}")
+                print(
+                    f"MODEL EVAL FINISH {args.arch}: time duration:{time.time()-ST}"
+                )
                 import sys
 
                 sys.exit(0)
diff --git a/script/test.sh b/script/test.sh
index 9c000b4..c72ca05 100644
--- a/script/test.sh
+++ b/script/test.sh
@@ -11,7 +11,8 @@ cd pprobe/tests
 PPROBE --enable PPROBE_ENABLE
 PPROBE --enable TORCH_DUMP_MODULE
-PPROBE --enable TORCH_CATCH_STEP
+PPROBE --enable TORCH_CATCH_STEP
+PPROBE --enable TORCH_CATCH_LR
 PPROBE_ENABLE=1 python xtest_torchvision_model.py -a resnet50 --epochs 1 -b 12 -p 1 --seed 42 --dummy
 PPROBE_ENABLE=1 python xtest_torchvision_model.py -a resnet50 --epochs 1 -b 12 -p 1 --seed 42 --dummy --evaluate
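Illustration (not part of the patch): a minimal, torch-free sketch of the monkey-patch pattern the new hooks rely on. A wrapper factory receives the original unbound method, returns a wrapper with the same call signature, and the wrapper is assigned back onto the class, so every subsequent call logs the learning rate before delegating. The names zero_grad_wrapper and ToyOptimizer are hypothetical stand-ins for optimizer_zero_grad_method_wrapper and torch.optim.Optimizer.

def zero_grad_wrapper(original_method):
    # Mirrors the shape of optimizer_zero_grad_method_wrapper: log each
    # param group's lr, then delegate with the caller's arguments.
    def wrapper(self, set_to_none=True):
        for group in self.param_groups:
            print(f"[PPROBE] Iteration optimizer lr ===> {group['lr']}")
        return original_method(self, set_to_none)

    return wrapper


class ToyOptimizer:
    # Hypothetical stand-in exposing the attributes the hook touches.
    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]

    def zero_grad(self, set_to_none=True):
        pass  # a real optimizer would clear gradients here


ToyOptimizer.zero_grad = zero_grad_wrapper(ToyOptimizer.zero_grad)
ToyOptimizer(lr=0.1).zero_grad()  # logs the lr, then delegates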