diff --git a/pprobe/bootstrap/hook_setup.py b/pprobe/bootstrap/hook_setup.py
index 73876d5..76cb79e 100644
--- a/pprobe/bootstrap/hook_setup.py
+++ b/pprobe/bootstrap/hook_setup.py
@@ -97,8 +97,7 @@ def check_and_run_hook(self, module_fullname):
             # TODO
             pass
         if self.torch_dump_dist_enabled:
-            # TODO
-            pass
+            self.run_torch_dist_hook()
         if self.torch_dump_module_enabled:
             self.run_torch_module_hook()
         if self.torch_dump_optim_enabled:
@@ -159,6 +158,7 @@ def run_torch_dist_hook(self):
 
         # 1. Communication Operations
        # torch.distributed.all_gather: Gathers data from all processes into a list.
+        # torch.distributed.all_reduce: Reduces data from all processes and broadcasts the result back to all processes.
         # torch.distributed.broadcast: Broadcasts data from the root process to all other processes.
         # torch.distributed.gather: Gathers data from all processes to the root process.
@@ -180,6 +180,9 @@ def run_torch_dist_hook(self):
         # torch.distributed.barrier: Performs a global synchronization operation where all processes wait until all processes reach the synchronization point before continuing.
         # torch.distributed.monitored_barrier: Similar to barrier, but supports timeouts and error reporting, useful for debugging and synchronization.
 
+        self.module.distributed.all_gather = func_torch_distributed_wrapper(
+            self.module.distributed.all_gather
+        )
         self.module.distributed.broadcast = func_torch_distributed_wrapper(
             self.module.distributed.broadcast
         )
@@ -189,9 +192,6 @@
         self.module.distributed.reduce = func_torch_distributed_wrapper(
             self.module.distributed.reduce
         )
-        self.module.distributed.all_gather = func_torch_distributed_wrapper(
-            self.module.distributed.all_gather
-        )
         self.module.distributed.gather = func_torch_distributed_wrapper(
             self.module.distributed.gather
         )
diff --git a/pprobe/bootstrap/hooks/pytorch_dist.py b/pprobe/bootstrap/hooks/pytorch_dist.py
index 4617e8d..fd0c619 100644
--- a/pprobe/bootstrap/hooks/pytorch_dist.py
+++ b/pprobe/bootstrap/hooks/pytorch_dist.py
@@ -2,13 +2,21 @@
 import functools
 from pprobe.utils.logging import Logger
 
+"""
+Per-tensor statistics the hook is expected to record:
+shape: val.shape
+dtype: val.dtype
+mean: val.mean().item()
+std: val.std().item()
+"""
+
 
 def func_torch_distributed_wrapper(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         if callable(func):
             result = func(*args, **kwargs)
-            # TODO: Refine the handling of each function.
+            if func.__module__ == 'torch.distributed' and func.__name__ == 'all_gather':
+                Logger.info(f"[PPROBE] all_gather intercepted: {func}")
             if isinstance(args, tuple):
                 Logger.info(f"[PPROBE] torch.distributed.{func.__qualname__}")
             else:
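
Note (not part of the diff above): the string added to pytorch_dist.py lists the per-tensor statistics the hook is meant to dump (shape, dtype, mean, std), while the wrapper so far only logs the function name. Below is a minimal sketch of how that per-function handling could be fleshed out; the helper names (_describe_tensor, _log_dist_call) are illustrative assumptions, and only Logger.info from pprobe.utils.logging is taken from the existing code.

# Hedged sketch, assuming torch is importable and that per-call tensor stats
# should go through the existing pprobe Logger; helper names are hypothetical.
import torch
from pprobe.utils.logging import Logger

def _describe_tensor(val):
    # Collect the stats listed in the module-level note: shape, dtype, mean, std.
    if not isinstance(val, torch.Tensor):
        return str(type(val))
    desc = f"shape={tuple(val.shape)}, dtype={val.dtype}"
    if val.is_floating_point():
        desc += f", mean={val.float().mean().item():.6f}, std={val.float().std().item():.6f}"
    return desc

def _log_dist_call(func, *args):
    # Log every tensor (or list of tensors, e.g. all_gather's output list)
    # passed to the wrapped torch.distributed call.
    for arg in args:
        items = arg if isinstance(arg, (list, tuple)) else [arg]
        for item in items:
            if isinstance(item, torch.Tensor):
                Logger.info(f"[PPROBE] {func.__qualname__}: {_describe_tensor(item)}")

In wrapper(), such a helper could be called right after result = func(*args, **kwargs), e.g. _log_dist_call(func, *args), which would cover all_gather as well as the other wrapped collectives.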