Commit: refactor for code structure

clemente0731 committed Jun 21, 2024
1 parent 19293fe commit 7ed0603
Showing 17 changed files with 302 additions and 205 deletions.
38 changes: 20 additions & 18 deletions README.cn.md
@@ -63,24 +63,26 @@ PPROBE --list
=================================================


+------------------+--------+---------+
| TOGGLE-NAMES     | STATUS | DEFAULT |
+------------------+--------+---------+
| REPRODUCE        | True   | True    |
| CATCH_STEP       | False  | False   |
| CATCH_LOSS       | False  | False   |
| CATCH_LR         | False  | False   |
| DUMP_OP          | False  | False   |
| DUMP_MODULE      | False  | False   |
| DUMP_DIST        | False  | False   |
| DUMP_MEMORY      | False  | False   |
| TEST_DUMP_OP     | False  | False   |
| TEST_DUMP_MODULE | False  | False   |
| TEST_DUMP_DIST   | False  | False   |
| PERF_ISSUE       | False  | False   |
| TRACE_FILE       | False  | False   |
+------------------+--------+---------+
```
+------------------------+--------+---------+
| TOGGLE-NAMES           | STATUS | DEFAULT |
+------------------------+--------+---------+
| PPROBE_ENABLE          | True   | False   |
| TORCH_REPRODUCE        | True   | True    |
| TORCH_CATCH_STEP       | False  | False   |
| TORCH_CATCH_LOSS       | False  | False   |
| TORCH_CATCH_LR         | False  | False   |
| TORCH_DUMP_OP          | True   | False   |
| TORCH_DUMP_MODULE      | False  | False   |
| TORCH_DUMP_DIST        | False  | False   |
| TORCH_DUMP_MEMORY      | False  | False   |
| TORCH_TEST_DUMP_OP     | False  | False   |
| TORCH_TEST_DUMP_MODULE | False  | False   |
| TORCH_TEST_DUMP_DIST   | False  | False   |
| TORCH_PERF_ISSUE       | False  | False   |
| TORCH_TRACE_FILE       | False  | False   |
+------------------------+--------+---------+


```
**Enable a specific option:**
36 changes: 19 additions & 17 deletions README.md
@@ -115,23 +115,25 @@ PPROBE --list
=================================================
+------------------+--------+---------+
| TOGGLE-NAMES     | STATUS | DEFAULT |
+------------------+--------+---------+
| REPRODUCE        | True   | True    |
| CATCH_STEP       | False  | False   |
| CATCH_LOSS       | False  | False   |
| CATCH_LR         | False  | False   |
| DUMP_OP          | False  | False   |
| DUMP_MODULE      | False  | False   |
| DUMP_DIST        | False  | False   |
| DUMP_MEMORY      | False  | False   |
| TEST_DUMP_OP     | False  | False   |
| TEST_DUMP_MODULE | False  | False   |
| TEST_DUMP_DIST   | False  | False   |
| PERF_ISSUE       | False  | False   |
| TRACE_FILE       | False  | False   |
+------------------+--------+---------+
+------------------------+--------+---------+
| TOGGLE-NAMES           | STATUS | DEFAULT |
+------------------------+--------+---------+
| PPROBE_ENABLE          | True   | False   |
| TORCH_REPRODUCE        | True   | True    |
| TORCH_CATCH_STEP       | False  | False   |
| TORCH_CATCH_LOSS       | False  | False   |
| TORCH_CATCH_LR         | False  | False   |
| TORCH_DUMP_OP          | True   | False   |
| TORCH_DUMP_MODULE      | False  | False   |
| TORCH_DUMP_DIST        | False  | False   |
| TORCH_DUMP_MEMORY      | False  | False   |
| TORCH_TEST_DUMP_OP     | False  | False   |
| TORCH_TEST_DUMP_MODULE | False  | False   |
| TORCH_TEST_DUMP_DIST   | False  | False   |
| TORCH_PERF_ISSUE       | False  | False   |
| TORCH_TRACE_FILE       | False  | False   |
+------------------------+--------+---------+
```
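
PPROBE_ENABLE gates everything else: when it is off, the import hook only prints a warning asking for PPROBE_ENABLE=1. A minimal sketch of enabling it from a training script, assuming the toggle is honored as an environment variable (as that warning suggests) and that it is set before torch is imported:

```
import os

# Assumption: PPROBE_ENABLE is read from the environment, per the warning in hook_setup.py.
os.environ["PPROBE_ENABLE"] = "1"

import torch  # with pprobe's import hook installed, the enabled hooks are applied here
```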

## License
135 changes: 135 additions & 0 deletions pprobe/bootstrap/hook_setup.py
@@ -0,0 +1,135 @@
import importlib
import sys
from pprobe.utils.logging import Logger
from pprobe.toggle.cli import ToggleManager

from pprobe.bootstrap.hooks.pytorch_catch import func_torch_step_count_wrapper
from pprobe.bootstrap.hooks.pytorch_dist import func_torch_distributed_wrapper
from pprobe.bootstrap.hooks.pytorch_perf import func_torch_device_conversion_wrapper


_hook_modules = {'torch'}


class PProbeSetup:
    def __init__(self, module, module_fullname):
        self.module = module
        self.pprobe_toggle = ToggleManager()
        self.pprobe_enabled = self.pprobe_toggle.get_toggle("PPROBE_ENABLE")
        self.torch_catch_step_enabled = self.pprobe_toggle.get_toggle("TORCH_CATCH_STEP")
        self.torch_reproduce_enabled = self.pprobe_toggle.get_toggle("TORCH_REPRODUCE")
        self.torch_dump_op_enabled = self.pprobe_toggle.get_toggle("TORCH_DUMP_OP")
        self.torch_dump_dist_enabled = self.pprobe_toggle.get_toggle("TORCH_DUMP_DIST")
        self.torch_perf_issue_enabled = self.pprobe_toggle.get_toggle("TORCH_PERF_ISSUE")

        self.check_and_run_hook(module_fullname)

    def check_and_run_hook(self, module_fullname):
        if self.pprobe_enabled:
            self.run_generic_hook()
            # torch part
            if module_fullname == "torch":
                if self.torch_catch_step_enabled:
                    self.run_torch_catch_step_hook()
                if self.torch_reproduce_enabled:
                    self.run_torch_reproduce_hook()
                if self.torch_dump_op_enabled:
                    self.run_torch_func_hook()
                if self.torch_dump_dist_enabled:
                    self.run_torch_dist_hook()
                if self.torch_perf_issue_enabled:
                    pass
        else:
            self.print_warning()

    def run_generic_hook(self):
        """
        Placeholder for hooks that are not tied to a specific framework.
        """
        pass

    def run_torch_func_hook(self):
        from pprobe.bootstrap.hooks import pytorch_func_op

        context = pytorch_func_op.TorchFunctionContext()
        context.__enter__()
        Logger.info("[PPROBE] torch function hook executed")

    def run_torch_dist_hook(self):
        Logger.info("[PPROBE] torch dist hook executed")

        ###################################################
        ## torch.distributed part
        ###################################################
        self.module.distributed.broadcast = func_torch_distributed_wrapper(self.module.distributed.broadcast)
        self.module.distributed.all_reduce = func_torch_distributed_wrapper(self.module.distributed.all_reduce)
        self.module.distributed.reduce = func_torch_distributed_wrapper(self.module.distributed.reduce)
        self.module.distributed.all_gather = func_torch_distributed_wrapper(self.module.distributed.all_gather)
        self.module.distributed.gather = func_torch_distributed_wrapper(self.module.distributed.gather)
        self.module.distributed.scatter = func_torch_distributed_wrapper(self.module.distributed.scatter)
        self.module.distributed.reduce_scatter = func_torch_distributed_wrapper(self.module.distributed.reduce_scatter)
        self.module.distributed.send = func_torch_distributed_wrapper(self.module.distributed.send)
        self.module.distributed.recv = func_torch_distributed_wrapper(self.module.distributed.recv)
        self.module.distributed.barrier = func_torch_distributed_wrapper(self.module.distributed.barrier)

    def run_torch_reproduce_hook(self):
        # Add the logic for torch reproduce hook
        Logger.info("[PPROBE] torch reproduce hook executed")

    def run_torch_catch_step_hook(self):
        ###################################################
        # torch.autograd.backward / torch.Tensor.backward
        ###################################################
        # Add the logic for torch catch_step hook
        Logger.info("[PPROBE] torch catch step hook executed")
        self.module.autograd.backward = func_torch_step_count_wrapper(self.module.autograd.backward)

    def run_torch_perf_hook(self):
        Logger.info("[PPROBE] torch perf hook executed")

        ###################################################
        ## torch.Tensor.to part
        ###################################################
        self.module.Tensor.to = func_torch_device_conversion_wrapper(self.module.Tensor.to)
        self.module.Tensor.cpu = func_torch_device_conversion_wrapper(self.module.Tensor.cpu)
        self.module.Tensor.cuda = func_torch_device_conversion_wrapper(self.module.Tensor.cuda)

    def print_warning(self):
        if not getattr(self, 'warning_printed', False):
            print("[PPROBE] Please set the environment variable PPROBE_ENABLE=1 to use pprobe.")
            setattr(self, 'warning_printed', True)


class MetaPathFinder:

    def find_module(self, module_fullname, path=None):
        # Logger.info('find_module {}'.format(module_fullname))
        # Only claim the modules we want to hook (legacy PEP 302 finder API).
        if module_fullname in _hook_modules:
            return MetaPathLoader()


class MetaPathLoader:

    def load_module(self, module_fullname):
        # Logger.info('load_module {}'.format(module_fullname))

        # sys.modules holds the modules that have already been imported.
        if module_fullname in sys.modules:
            return sys.modules[module_fullname]

        ##################################################
        # Remove our custom finder from sys.meta_path first,
        # so that the import_module call below does not hit
        # this finder again and recurse.
        ##################################################
        finder = sys.meta_path.pop(0)
        module = importlib.import_module(module_fullname)

        # Logger.info(f"META-PATH-LOADER --> MODULE {module}")
        pprobe = PProbeSetup(module, module_fullname)

        sys.meta_path.insert(0, finder)

        return pprobe.module


sys.meta_path.insert(0, MetaPathFinder())
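
The finder/loader pair above is a standard sys.meta_path import hook: the finder claims only the modules listed in _hook_modules, and the loader temporarily pops the finder, performs the real import, runs PProbeSetup on the imported module, then reinstalls the finder. A minimal sketch of the intended flow, assuming this module is imported before torch anywhere in user code:

```
# Assumes pprobe's bootstrap imports this module early, before any `import torch`.
import pprobe.bootstrap.hook_setup  # installs MetaPathFinder at the front of sys.meta_path

import torch  # intercepted: MetaPathLoader does the real import and PProbeSetup
              # applies whichever hooks the toggles enable before user code sees torch
```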
Empty file.
25 changes: 25 additions & 0 deletions pprobe/bootstrap/hooks/pytorch_catch.py
@@ -0,0 +1,25 @@
import time
import functools
from pprobe.utils.logging import Logger

func_counts = 0


def func_torch_step_count_wrapper(func):
    ###################################################
    # torch.autograd.backward / torch.Tensor.backward
    ###################################################
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global func_counts

        if callable(func):
            start = time.time()
            result = func(*args, **kwargs)
            end = time.time()
            func_counts += 1
            Logger.info(f"[PPROBE] func_name {func} --> counts {func_counts}, took {end - start:.6f}s")
            return result
        else:
            # handle the case where func is not callable
            Logger.warn(f"func:{func} is not callable")

    return wrapper
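
In hook_setup.py this wrapper is applied to torch.autograd.backward by run_torch_catch_step_hook, so every backward call bumps the counter. A minimal standalone sketch with a hypothetical stand-in function, just to show the counting behaviour:

```
# fake_backward is a hypothetical stand-in; in pprobe the real target is torch.autograd.backward.
def fake_backward(loss):
    return None

counted_backward = func_torch_step_count_wrapper(fake_backward)
counted_backward(0.5)  # logs: [PPROBE] func_name <function fake_backward ...> --> counts 1
counted_backward(0.5)  # logs: ... --> counts 2
```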
Empty file.
19 changes: 19 additions & 0 deletions pprobe/bootstrap/hooks/pytorch_dist.py
@@ -0,0 +1,19 @@
import functools
from pprobe.utils.logging import Logger


def func_torch_distributed_wrapper(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if callable(func):
            result = func(*args, **kwargs)
            # TODO: Refine the handling of each function.
            Logger.info(f"[PPROBE] torch.distributed.{func.__qualname__}")
            return result
        else:
            # handle the case where func is not callable
            Logger.warn(f"[PPROBE] func:{func} is not callable")

    return wrapper
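
run_torch_dist_hook in hook_setup.py wraps the torch.distributed collectives (broadcast, all_reduce, reduce, all_gather, ...) with this function, so each call is logged under its qualified name. A standalone sketch using a hypothetical stand-in collective:

```
# all_reduce here is a hypothetical stand-in; in pprobe the wrapper is applied
# to the real torch.distributed primitives.
def all_reduce(tensor, op=None):
    return tensor

all_reduce = func_torch_distributed_wrapper(all_reduce)
all_reduce([1.0, 2.0, 3.0])  # logs: [PPROBE] torch.distributed.all_reduce
```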
File renamed without changes.
Empty file.
24 changes: 24 additions & 0 deletions pprobe/bootstrap/hooks/pytorch_perf.py
@@ -0,0 +1,24 @@
import functools
from pprobe.utils.logging import Logger


def func_torch_device_conversion_wrapper(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if callable(func):
            tensor_ret = func(*args, **kwargs)

            if func.__name__ == "to":
                Logger.warn(f"[PPROBE] found device conversion call {func}, the tensor was converted to {str(tensor_ret.device)}")
            elif func.__name__ == "cpu":
                Logger.warn(f"[PPROBE] found device conversion call {func}, the tensor is on CPU")
            elif func.__name__ == "cuda":
                Logger.info(f"[PPROBE] found device conversion call {func}, the tensor is on {str(tensor_ret.device)}")

            return tensor_ret
        else:
            # handle the case where func is not callable
            Logger.warn(f"func:{func} is not callable")

    return wrapper
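
run_torch_perf_hook applies this wrapper to torch.Tensor.to, .cpu, and .cuda so that device transfers show up in the log. A short sketch applying the same monkey-patch by hand (assumes PyTorch is installed and the wrapper above is importable):

```
import torch

# Same patch that run_torch_perf_hook performs, applied manually here.
torch.Tensor.to = func_torch_device_conversion_wrapper(torch.Tensor.to)

x = torch.zeros(2, 2)
y = x.to("cpu")  # logs a [PPROBE] warning naming the call and the resulting device
```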


100 changes: 0 additions & 100 deletions pprobe/bootstrap/pt_specific_hook.py

This file was deleted.

