Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
clemente0731 committed Aug 6, 2024
1 parent 6425ab8 commit 18f5e33
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
10 changes: 5 additions & 5 deletions pprobe/bootstrap/hook_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def check_and_run_hook(self, module_fullname):
# TODO
pass
if self.torch_dump_dist_enabled:
# TODO
pass
self.run_torch_dist_hook()
if self.torch_dump_module_enabled:
self.run_torch_module_hook()
if self.torch_dump_optim_enabled:
Expand Down Expand Up @@ -159,6 +158,7 @@ def run_torch_dist_hook(self):

# 1. Communication Operations
# torch.distributed.all_gather: Gathers data from all processes into a list.

# torch.distributed.all_reduce: Reduces data from all processes and broadcasts the result back to all processes.
# torch.distributed.broadcast: Broadcasts data from the root process to all other processes.
# torch.distributed.gather: Gathers data from all processes to the root process.
Expand All @@ -180,6 +180,9 @@ def run_torch_dist_hook(self):
# torch.distributed.barrier: Performs a global synchronization operation where all processes wait until all processes reach the synchronization point before continuing.
# torch.distributed.monitored_barrier: Similar to barrier, but supports timeouts and error reporting, useful for debugging and synchronization.

self.module.distributed.all_gather = func_torch_distributed_wrapper(
self.module.distributed.all_gather
)
self.module.distributed.broadcast = func_torch_distributed_wrapper(
self.module.distributed.broadcast
)
Expand All @@ -189,9 +192,6 @@ def run_torch_dist_hook(self):
self.module.distributed.reduce = func_torch_distributed_wrapper(
self.module.distributed.reduce
)
self.module.distributed.all_gather = func_torch_distributed_wrapper(
self.module.distributed.all_gather
)
self.module.distributed.gather = func_torch_distributed_wrapper(
self.module.distributed.gather
)
Expand Down
10 changes: 9 additions & 1 deletion pprobe/bootstrap/hooks/pytorch_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@
import functools
from pprobe.utils.logging import Logger

"""
Per-tensor statistics recorded for dumped values:
shape: val.shape
dtype: val.dtype
mean: val.mean().item()
std: val.std().item()
"""


def func_torch_distributed_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if callable(func):
result = func(*args, **kwargs)
# TODO: Refine the handling of each function.
if func.__module__ == 'torch.distributed' and func.__name__ == 'all_gather':
print("xxxxxxx", func)
if isinstance(args, tuple):
Logger.info(f"[PPROBE] torch.distributed.{func.__qualname__}")
else:
Expand Down

0 comments on commit 18f5e33

Please sign in to comment.