From ef461ef8da0f5bdecbdceedb579b2bb291e31734 Mon Sep 17 00:00:00 2001 From: clemente0731 Date: Wed, 3 Jul 2024 14:40:19 +0800 Subject: [PATCH] feat: Add torch module hook functionality - Added `torch_dump_module_enabled` toggle to PProbeSetup to enable torch module hooks. - Introduced `TorchModuleContext` in `pytorch_module.py` to scan and register hooks for all `torch.nn.Module` instances. - Implemented the `_scan_and_register_hooks` method to identify and register hooks for torch modules. - Created the `run_torch_module_hook` method in PProbeSetup to initialize the module hook context. - Updated `script/test.sh` to include `TORCH_DUMP_MODULE` toggle for testing module hooks. This update allows for the capture of input and output tensors for each forward pass of `torch.nn.Module` instances, providing deeper insight into model behavior during execution. --- pprobe/bootstrap/hook_setup.py | 101 +++++++++++++----- pprobe/bootstrap/hooks/pytorch_aten_op.py | 0 pprobe/bootstrap/hooks/pytorch_catch.py | 6 +- pprobe/bootstrap/hooks/pytorch_dist.py | 4 +- pprobe/bootstrap/hooks/pytorch_func_op.py | 17 ++- pprobe/bootstrap/hooks/pytorch_module.py | 53 +++++++++ pprobe/bootstrap/hooks/pytorch_perf.py | 16 ++- pprobe/bootstrap/sitecustomize.py | 8 +- .../xtest_device_conversion_detection.py | 2 +- pprobe/tests/xtest_torchvision_model.py | 3 +- pprobe/toggle/cli.py | 18 ++-- pprobe/toggle/tabulate.py | 2 +- pprobe/utils/logging.py | 3 +- pprobe/utils/trace.py | 6 +- script/build.sh | 16 +++ script/test.sh | 5 + setup.py | 9 +- 17 files changed, 208 insertions(+), 61 deletions(-) create mode 100644 pprobe/bootstrap/hooks/pytorch_aten_op.py diff --git a/pprobe/bootstrap/hook_setup.py b/pprobe/bootstrap/hook_setup.py index ce7846f..4e3e960 100644 --- a/pprobe/bootstrap/hook_setup.py +++ b/pprobe/bootstrap/hook_setup.py @@ -8,20 +8,26 @@ from pprobe.bootstrap.hooks.pytorch_perf import func_torch_device_conversion_wrapper - -_hook_modules = {'torch'} +_hook_modules = {"torch"} class PProbeSetup: def __init__(self, module, module_fullname): - self.module = module + self.module = module self.pprobe_toggle = ToggleManager() self.pprobe_enabled = self.pprobe_toggle.get_toggle("PPROBE_ENABLE") - self.torch_catch_step_enabled = self.pprobe_toggle.get_toggle("TORCH_CATCH_STEP") + self.torch_catch_step_enabled = self.pprobe_toggle.get_toggle( + "TORCH_CATCH_STEP" + ) self.torch_reproduce_enabled = self.pprobe_toggle.get_toggle("TORCH_REPRODUCE") self.torch_dump_op_enabled = self.pprobe_toggle.get_toggle("TORCH_DUMP_OP") self.torch_dump_dist_enabled = self.pprobe_toggle.get_toggle("TORCH_DUMP_DIST") - self.torch_perf_issue_enabled = self.pprobe_toggle.get_toggle("TORCH_PERF_ISSUE") + self.torch_dump_module_enabled = self.pprobe_toggle.get_toggle( + "TORCH_DUMP_MODULE" + ) + self.torch_perf_issue_enabled = self.pprobe_toggle.get_toggle( + "TORCH_PERF_ISSUE" + ) self.check_and_run_hook(module_fullname) @@ -36,6 +42,8 @@ def check_and_run_hook(self, module_fullname): self.run_torch_reproduce_hook() if self.torch_dump_op_enabled: self.run_torch_func_hook() + if self.torch_dump_module_enabled: + self.run_torch_module_hook() if self.torch_dump_dist_enabled: self.run_torch_func_hook() if self.torch_perf_issue_enabled: @@ -51,26 +59,54 @@ def run_generic_hook(self): def run_torch_func_hook(self): from pprobe.bootstrap.hooks import pytorch_func_op + context = pytorch_func_op.TorchFunctionContext() context.__enter__() Logger.info(f"[PPROBE] torch function hook executed") + def run_torch_module_hook(self): + from 
pprobe.bootstrap.hooks import pytorch_module + + context = pytorch_module.TorchModuleContext() + context.__enter__() + Logger.info(f"[PPROBE] torch module hook executed") + def run_torch_dist_hook(self): Logger.info(f"[PPROBE] torch dist hook executed") ################################################### ## torch.distributed part ################################################### - self.module.distributed.broadcast = func_torch_distributed_wrapper(self.module.distributed.broadcast) - self.module.distributed.all_reduce = func_torch_distributed_wrapper(self.module.distributed.all_reduce) - self.module.distributed.reduce = func_torch_distributed_wrapper(self.module.distributed.reduce) - self.module.distributed.all_gather = func_torch_distributed_wrapper(self.module.distributed.all_gather) - self.module.distributed.gather = func_torch_distributed_wrapper(self.module.distributed.gather) - self.module.distributed.scatter = func_torch_distributed_wrapper(self.module.distributed.scatter) - self.module.distributed.reduce_scatter = func_torch_distributed_wrapper(self.module.distributed.reduce_scatter) - self.module.distributed.send = func_torch_distributed_wrapper(self.module.distributed.send) - self.module.distributed.recv = func_torch_distributed_wrapper(self.module.distributed.recv) - self.module.distributed.barrier = func_torch_distributed_wrapper(self.module.distributed.barrier) + self.module.distributed.broadcast = func_torch_distributed_wrapper( + self.module.distributed.broadcast + ) + self.module.distributed.all_reduce = func_torch_distributed_wrapper( + self.module.distributed.all_reduce + ) + self.module.distributed.reduce = func_torch_distributed_wrapper( + self.module.distributed.reduce + ) + self.module.distributed.all_gather = func_torch_distributed_wrapper( + self.module.distributed.all_gather + ) + self.module.distributed.gather = func_torch_distributed_wrapper( + self.module.distributed.gather + ) + self.module.distributed.scatter = func_torch_distributed_wrapper( + self.module.distributed.scatter + ) + self.module.distributed.reduce_scatter = func_torch_distributed_wrapper( + self.module.distributed.reduce_scatter + ) + self.module.distributed.send = func_torch_distributed_wrapper( + self.module.distributed.send + ) + self.module.distributed.recv = func_torch_distributed_wrapper( + self.module.distributed.recv + ) + self.module.distributed.barrier = func_torch_distributed_wrapper( + self.module.distributed.barrier + ) def run_torch_reproduce_hook(self): # Add the logic for torch reproduce hook @@ -78,45 +114,53 @@ def run_torch_reproduce_hook(self): def run_torch_catch_step_hook(self): ################################################### - # torch.autograd.backward / torch.Tensor.backward + # torch.autograd.backward / torch.Tensor.backward ################################################### # Add the logic for torch catch_step hook Logger.info(f"[PPROBE] torch catch step hook executed") - self.module.autograd.backward = func_torch_step_count_wrapper(self.module.autograd.backward) + self.module.autograd.backward = func_torch_step_count_wrapper( + self.module.autograd.backward + ) def run_torch_perf_hook(self): - Logger.info(f"[PPROBE] torch perf hook executed") ################################################### ## torch.Tensor.to part ################################################### - self.module.Tensor.to = func_torch_device_conversion_wrapper(self.module.Tensor.to) - self.module.Tensor.cpu = func_torch_device_conversion_wrapper(self.module.Tensor.cpu) - 
self.module.Tensor.cuda = func_torch_device_conversion_wrapper(self.module.Tensor.cuda) + self.module.Tensor.to = func_torch_device_conversion_wrapper( + self.module.Tensor.to + ) + self.module.Tensor.cpu = func_torch_device_conversion_wrapper( + self.module.Tensor.cpu + ) + self.module.Tensor.cuda = func_torch_device_conversion_wrapper( + self.module.Tensor.cuda + ) def print_warning(self): - if not getattr(self, 'warning_printed', False): - print("[PPROBE] Please set the environment variable PPROBE_ENABLE=1 to use pprobe.") - setattr(self, 'warning_printed', True) + if not getattr(self, "warning_printed", False): + print( + "[PPROBE] Please set the environment variable PPROBE_ENABLE=1 to use pprobe." + ) + setattr(self, "warning_printed", True) class MetaPathFinder: - def find_module(self, module_fullname, path=None): # Logger.info('find_module {}'.format(module_fullname)) if module_fullname in _hook_modules: return MetaPathLoader() -class MetaPathLoader: +class MetaPathLoader: def load_module(self, module_fullname): # Logger.info('load_module {}'.format(module_fullname)) # sys.modules中保存的是已经导入过的 module if module_fullname in sys.modules: return sys.modules[module_fullname] - + ################################################## # 先从 sys.meta_path 中删除自定义的 finder # 防止下面执行 import_module 的时候再次触发此 finder @@ -127,9 +171,10 @@ def load_module(self, module_fullname): # Logger.info(f"META-PATH-LOADER --> MODULE {module}") pprobe = PProbeSetup(module, module_fullname) - + sys.meta_path.insert(0, finder) return pprobe.module + sys.meta_path.insert(0, MetaPathFinder()) diff --git a/pprobe/bootstrap/hooks/pytorch_aten_op.py b/pprobe/bootstrap/hooks/pytorch_aten_op.py new file mode 100644 index 0000000..e69de29 diff --git a/pprobe/bootstrap/hooks/pytorch_catch.py b/pprobe/bootstrap/hooks/pytorch_catch.py index a5b6fda..4c5d925 100644 --- a/pprobe/bootstrap/hooks/pytorch_catch.py +++ b/pprobe/bootstrap/hooks/pytorch_catch.py @@ -4,9 +4,10 @@ func_counts = 0 + def func_torch_step_count_wrapper(func): ################################################### - # torch.autograd.backward / torch.Tensor.backward + # torch.autograd.backward / torch.Tensor.backward ################################################### @functools.wraps(func) def wrapper(*args, **kwargs): @@ -22,4 +23,5 @@ def wrapper(*args, **kwargs): else: # handle the case where func is not callable Logger.warn(f"func:{func} is not callable") - return wrapper \ No newline at end of file + + return wrapper diff --git a/pprobe/bootstrap/hooks/pytorch_dist.py b/pprobe/bootstrap/hooks/pytorch_dist.py index 862cc86..4617e8d 100644 --- a/pprobe/bootstrap/hooks/pytorch_dist.py +++ b/pprobe/bootstrap/hooks/pytorch_dist.py @@ -2,6 +2,7 @@ import functools from pprobe.utils.logging import Logger + def func_torch_distributed_wrapper(func): @functools.wraps(func) def wrapper(*args, **kwargs): @@ -16,4 +17,5 @@ def wrapper(*args, **kwargs): else: # handle the case where func is not callable Logger.warn(f"[PPROBE] func:{func} is not callable") - return wrapper \ No newline at end of file + + return wrapper diff --git a/pprobe/bootstrap/hooks/pytorch_func_op.py b/pprobe/bootstrap/hooks/pytorch_func_op.py index 62eb19c..04040a8 100644 --- a/pprobe/bootstrap/hooks/pytorch_func_op.py +++ b/pprobe/bootstrap/hooks/pytorch_func_op.py @@ -29,6 +29,7 @@ class TorchFunctionContext(TorchFunctionMode): context.__exit__() """ + def __init__(self): super().__init__() self.func_idx = 0 @@ -147,7 +148,9 @@ def __torch_function__(self, func, types, args, kwargs=None): output = 
func(*args, **(kwargs or {})) - if resolve_name(func) and any(keyword in resolve_name(func) for keyword in ["dtype", "shape"]): + if resolve_name(func) and any( + keyword in resolve_name(func) for keyword in ["dtype", "shape"] + ): # If the function name includes "torch.Tensor.dtype" or "torch.Tensor.shape", return the output value directly without further processing return output @@ -204,6 +207,7 @@ class TorchFunctionMiniContext(TorchFunctionMode): context.__exit__() """ + def __init__(self): super().__init__() self.func_idx = 0 @@ -243,7 +247,9 @@ def generate_filename(self): # Check if it's multi-GPU multi-rank or single-GPU if self.is_multi_gpu_multi_rank(): rank = dist.get_rank() - filename = f"{now.strftime('%Y%m%d%H%M')}_mini_function_rank_{rank}_dump.csv" + filename = ( + f"{now.strftime('%Y%m%d%H%M')}_mini_function_rank_{rank}_dump.csv" + ) else: filename = f"{now.strftime('%Y%m%d%H%M')}_mini_function_rank_0_dump.csv" @@ -260,11 +266,12 @@ def __torch_function__(self, func, types, args, kwargs=None): output = func(*args, **(kwargs or {})) - - if resolve_name(func) and any(keyword in resolve_name(func) for keyword in ["dtype", "shape"]): + if resolve_name(func) and any( + keyword in resolve_name(func) for keyword in ["dtype", "shape"] + ): # If the function name includes "torch.Tensor.dtype" or "torch.Tensor.shape", return the output value directly without further processing return output - + print(f"{resolve_name(func)}() ===== type(output) {type(output)}", flush=True) # 填充缺失值 diff --git a/pprobe/bootstrap/hooks/pytorch_module.py b/pprobe/bootstrap/hooks/pytorch_module.py index e69de29..0b87c00 100644 --- a/pprobe/bootstrap/hooks/pytorch_module.py +++ b/pprobe/bootstrap/hooks/pytorch_module.py @@ -0,0 +1,53 @@ +import torch +from contextlib import ContextDecorator + + +class TorchModuleContext(ContextDecorator): + def __init__(self): + self.hooks = [] + + def __enter__(self): + self._scan_and_register_hooks() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._remove_hooks() + + def _scan_and_register_hooks(self): + """ + Scan the global namespace and register hooks for all torch.nn.Module instances. + """ + print("====== _scan_and_register_hooks =========") + for obj in globals().values(): + print(f"########################## {obj}") + if isinstance(obj, torch.nn.Module): + self._register_hooks_for_model(obj) + + def _register_hooks_for_model(self, model): + m_tuple = self.get_named_modules(model) + for name, m in m_tuple: + self._register_hook(name, m) + + def get_named_modules(self, module): + """ + Return a list of (name, module) tuples from the module. + """ + return list(module.named_modules()) + + def _register_hook(self, name, module): + """ + Register a hook for the given module. 
+ """ + + # Define the hook function + def hook_fn(module, input, output): + print(f"Hook for {name}: input = {input}, output = {output}") + + # Register the forward hook and store the hook handle + handle = module.register_forward_hook(hook_fn) + self.hooks.append(handle) + + def _remove_hooks(self): + for handle in self.hooks: + handle.remove() + self.hooks = [] diff --git a/pprobe/bootstrap/hooks/pytorch_perf.py b/pprobe/bootstrap/hooks/pytorch_perf.py index e37909e..c426c0d 100644 --- a/pprobe/bootstrap/hooks/pytorch_perf.py +++ b/pprobe/bootstrap/hooks/pytorch_perf.py @@ -2,6 +2,7 @@ import functools from pprobe.utils.logging import Logger + def func_torch_device_conversion_wrapper(func): @functools.wraps(func) def wrapper(*args, **kwargs): @@ -9,16 +10,21 @@ def wrapper(*args, **kwargs): tensor_ret = func(*args, **kwargs) if func.__name__ == "to": - Logger.warn(f"[PPROBE] find device conversion call {func}, The tensor is conversion to {str(tensor_ret.device)}") + Logger.warn( + f"[PPROBE] find device conversion call {func}, The tensor is conversion to {str(tensor_ret.device)}" + ) elif func.__name__ == "cpu": - Logger.warn(f"[PPROBE] find device conversion call {func}, The tensor is on CPU") + Logger.warn( + f"[PPROBE] find device conversion call {func}, The tensor is on CPU" + ) elif func.__name__ == "cuda": - Logger.info(f"[PPROBE] find device conversion call {func}, The tensor is on {str(tensor_ret.device)}") + Logger.info( + f"[PPROBE] find device conversion call {func}, The tensor is on {str(tensor_ret.device)}" + ) return tensor_ret else: # handle the case where func is not callable Logger.warn(f"func:{func} is not callable") - return wrapper - + return wrapper diff --git a/pprobe/bootstrap/sitecustomize.py b/pprobe/bootstrap/sitecustomize.py index e91c059..cf1fd6d 100644 --- a/pprobe/bootstrap/sitecustomize.py +++ b/pprobe/bootstrap/sitecustomize.py @@ -1,5 +1,6 @@ import os + def check_and_run_hook(): """ Check the environment variable. 
If PPROBE is enabled, then execute the _hook function; @@ -13,9 +14,10 @@ def check_and_run_hook(): from pprobe.bootstrap import hook_setup else: # Print the warning message only once - if not getattr(check_and_run_hook, 'warning_printed', False): + if not getattr(check_and_run_hook, "warning_printed", False): # print("[PPROBE] Please set the environment variable PPROBE_ENABLE=1/2/3/4 to use pprobe.") - setattr(check_and_run_hook, 'warning_printed', True) + setattr(check_and_run_hook, "warning_printed", True) + # Call the function to check and run the hook -check_and_run_hook() \ No newline at end of file +check_and_run_hook() diff --git a/pprobe/tests/xtest_device_conversion_detection.py b/pprobe/tests/xtest_device_conversion_detection.py index f80ef52..f1f9a66 100644 --- a/pprobe/tests/xtest_device_conversion_detection.py +++ b/pprobe/tests/xtest_device_conversion_detection.py @@ -27,4 +27,4 @@ # Method 3: Using model.cpu() model_to_cpu_3 = model.cpu() -print(f"Using device: {device}") \ No newline at end of file +print(f"Using device: {device}") diff --git a/pprobe/tests/xtest_torchvision_model.py b/pprobe/tests/xtest_torchvision_model.py index ee76280..48cfc4b 100644 --- a/pprobe/tests/xtest_torchvision_model.py +++ b/pprobe/tests/xtest_torchvision_model.py @@ -469,6 +469,7 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args): if i >= 2: print(f"MODEL TRAIN FINISH {args.arch}: time duration:{time.time()-ST}") import sys + sys.exit(0) @@ -652,4 +653,4 @@ def accuracy(output, target, topk=(1,)): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pprobe/toggle/cli.py b/pprobe/toggle/cli.py index 874d876..876467f 100644 --- a/pprobe/toggle/cli.py +++ b/pprobe/toggle/cli.py @@ -1,4 +1,3 @@ - # !/usr/bin/env python """ @@ -52,7 +51,7 @@ # default_status: str -class ToggleManager(): +class ToggleManager: def __init__(self): self.default_toggle = collections.OrderedDict() self.running_toggle = collections.OrderedDict() @@ -129,17 +128,21 @@ def reset_toggle(self): self._save_toggles_to_file(self.running_toggle_path, self.running_toggle) self.show_status() - def _save_toggles_to_file(self, file_path, toggle_dict): try: - with file_path.open('w') as file: + with file_path.open("w") as file: for name, value in toggle_dict.items(): - value_str = "true" if value is True else "false" if value is False else value + value_str = ( + "true" + if value is True + else "false" + if value is False + else value + ) file.write(f"{name}={value_str}\n") except Exception as e: print(f"Error writing to {file_path}: {e}") - def show_status(self): """ when printing: @@ -186,7 +189,6 @@ def main(): parser.add_argument("--reset", action="store_true", help="Reset the toggle") args = parser.parse_args() - toggle_instance = ToggleManager() if args.enable: @@ -210,4 +212,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pprobe/toggle/tabulate.py b/pprobe/toggle/tabulate.py index 63e5393..80c3fb3 100644 --- a/pprobe/toggle/tabulate.py +++ b/pprobe/toggle/tabulate.py @@ -1800,4 +1800,4 @@ def _pprint_file(fobject, headers, tablefmt, sep, floatfmt, file, colalign): if __name__ == "__main__": - _main() \ No newline at end of file + _main() diff --git a/pprobe/utils/logging.py b/pprobe/utils/logging.py index f3e3429..86e52aa 100644 --- a/pprobe/utils/logging.py +++ b/pprobe/utils/logging.py @@ -17,10 +17,11 @@ class Color: END = "\033[0m" -class Logger(): +class Logger: """ implementation of Logger """ + 
@staticmethod def print_c(msg, color: Color): print(color + msg + Color.END) diff --git a/pprobe/utils/trace.py b/pprobe/utils/trace.py index 4965fcb..b686ee2 100644 --- a/pprobe/utils/trace.py +++ b/pprobe/utils/trace.py @@ -1,6 +1,6 @@ - import traceback + def trace_function_call(): """ Trace the call stack and log each entry except the last two frames. @@ -17,4 +17,6 @@ def trace_function_call(): stack_trace = traceback.extract_stack() # Print the stack trace information for stack_entry in stack_trace[:-2]: - Logger.warn(f"\t\t Trace File: {stack_entry.filename}, Line: {stack_entry.lineno}") + Logger.warn( + f"\t\t Trace File: {stack_entry.filename}, Line: {stack_entry.lineno}" + ) diff --git a/script/build.sh b/script/build.sh index d716385..cae6025 100644 --- a/script/build.sh +++ b/script/build.sh @@ -10,6 +10,22 @@ cd "${project_dir}" # 清理项目目录 # git clean -dxf + + +# Install black globally if not already installed +if ! command -v black &> /dev/null +then + echo "Installing black..." + pip install black || true +else + echo "Black is already installed." +fi + +# Format all Python files in the project using black +echo "Formatting Python files with black..." +black . || true + + # 构建 wheel 包 python setup.py bdist_wheel diff --git a/script/test.sh b/script/test.sh index 8769783..e38368a 100644 --- a/script/test.sh +++ b/script/test.sh @@ -9,5 +9,10 @@ cd "${project_dir}" cd pprobe/tests +PPROBE --enable PPROBE_ENABLE +PPROBE --enable TORCH_DUMP_MODULE + PPROBE_ENABLE=1 python xtest_torchvision_model.py -a resnet50 --epochs 1 -b 12 -p 1 --seed 42 --dummy PPROBE_ENABLE=1 python xtest_device_conversion_detection.py + +PPROBE --reset \ No newline at end of file diff --git a/setup.py b/setup.py index a70251d..d06466e 100644 --- a/setup.py +++ b/setup.py @@ -19,8 +19,10 @@ def read_requirements(file_path): print("pprobe install_requires:", f.read().splitlines()) return f.read().splitlines() + class build_py_with_pth_file(build_py): """Include the .pth file for this project, in the generated wheel.""" + def run(self): super().run() @@ -34,8 +36,8 @@ def copy_pth(self): self.copy_file(location_in_source_tree, outfile, preserve_mode=0) def copy_toggle(self): - src_file = 'pprobe/toggle/hook.toggle.default' - dst_file = 'pprobe/toggle/hook.toggle.running' + src_file = "pprobe/toggle/hook.toggle.default" + dst_file = "pprobe/toggle/hook.toggle.running" dst_build_file = os.path.join(self.build_lib, dst_file) try: shutil.copyfile(src_file, dst_build_file) @@ -43,6 +45,7 @@ def copy_toggle(self): except FileNotFoundError: print(f"Source file {src_file} does not exist") + setup( name="pprobe", version="1.0.0", @@ -67,4 +70,4 @@ def copy_toggle(self): install_requires=read_requirements("./requirements.txt"), zip_safe=False, cmdclass={"build_py": build_py_with_pth_file}, -) \ No newline at end of file +)
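
Below the diff, a minimal sketch of the forward-hook mechanism that the new TorchModuleContext relies on: register_forward_hook() fires once per forward pass and receives the module, its positional inputs, and its output, which is exactly what hook_fn inside _register_hook prints. The toy model and layer names here are illustrative assumptions, not code taken from this patch; end to end, the same behaviour is exercised by enabling the TORCH_DUMP_MODULE toggle in script/test.sh and running the test scripts with PPROBE_ENABLE=1.

import torch
import torch.nn as nn

# Illustrative toy model (not part of the patch).
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

def make_hook(name):
    # Mirrors hook_fn in TorchModuleContext._register_hook: called on every
    # forward pass with (module, inputs, output).
    def hook_fn(module, inputs, output):
        print(f"{name}: in={[tuple(t.shape) for t in inputs]}, out={tuple(output.shape)}")
    return hook_fn

# TorchModuleContext._register_hooks_for_model walks named_modules() in the same
# way and keeps the returned handles so that __exit__ can remove them later.
handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules() if n]

model(torch.randn(3, 8))  # prints one line per hooked submodule

for h in handles:  # same cleanup that TorchModuleContext._remove_hooks performs
    h.remove()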