Commit 7ed0603 (1 parent: 19293fe)
Showing 17 changed files with 302 additions and 205 deletions.
@@ -0,0 +1,135 @@
import importlib
import sys
from pprobe.utils.logging import Logger
from pprobe.toggle.cli import ToggleManager

from pprobe.bootstrap.hooks.pytorch_catch import func_torch_step_count_wrapper
from pprobe.bootstrap.hooks.pytorch_dist import func_torch_distributed_wrapper
from pprobe.bootstrap.hooks.pytorch_perf import func_torch_device_conversion_wrapper


_hook_modules = {'torch'}

class PProbeSetup:
    def __init__(self, module, module_fullname):
        self.module = module
        self.pprobe_toggle = ToggleManager()
        self.pprobe_enabled = self.pprobe_toggle.get_toggle("PPROBE_ENABLE")
        self.torch_catch_step_enabled = self.pprobe_toggle.get_toggle("TORCH_CATCH_STEP")
        self.torch_reproduce_enabled = self.pprobe_toggle.get_toggle("TORCH_REPRODUCE")
        self.torch_dump_op_enabled = self.pprobe_toggle.get_toggle("TORCH_DUMP_OP")
        self.torch_dump_dist_enabled = self.pprobe_toggle.get_toggle("TORCH_DUMP_DIST")
        self.torch_perf_issue_enabled = self.pprobe_toggle.get_toggle("TORCH_PERF_ISSUE")

        self.check_and_run_hook(module_fullname)

    def check_and_run_hook(self, module_fullname):
        if self.pprobe_enabled:
            self.run_generic_hook()
            # torch part
            if module_fullname == "torch":
                if self.torch_catch_step_enabled:
                    self.run_torch_catch_step_hook()
                if self.torch_reproduce_enabled:
                    self.run_torch_reproduce_hook()
                if self.torch_dump_op_enabled:
                    self.run_torch_func_hook()
                if self.torch_dump_dist_enabled:
                    self.run_torch_dist_hook()
                if self.torch_perf_issue_enabled:
                    self.run_torch_perf_hook()
        else:
            self.print_warning()

    def run_generic_hook(self):
        """
        Placeholder for hooks that apply to every intercepted module.
        """
        pass

    def run_torch_func_hook(self):
        from pprobe.bootstrap.hooks import pytorch_func_op
        context = pytorch_func_op.TorchFunctionContext()
        context.__enter__()
        Logger.info("[PPROBE] torch function hook executed")

    def run_torch_dist_hook(self):
        Logger.info("[PPROBE] torch dist hook executed")

        ###################################################
        ## torch.distributed part
        ###################################################
        self.module.distributed.broadcast = func_torch_distributed_wrapper(self.module.distributed.broadcast)
        self.module.distributed.all_reduce = func_torch_distributed_wrapper(self.module.distributed.all_reduce)
        self.module.distributed.reduce = func_torch_distributed_wrapper(self.module.distributed.reduce)
        self.module.distributed.all_gather = func_torch_distributed_wrapper(self.module.distributed.all_gather)
        self.module.distributed.gather = func_torch_distributed_wrapper(self.module.distributed.gather)
        self.module.distributed.scatter = func_torch_distributed_wrapper(self.module.distributed.scatter)
        self.module.distributed.reduce_scatter = func_torch_distributed_wrapper(self.module.distributed.reduce_scatter)
        self.module.distributed.send = func_torch_distributed_wrapper(self.module.distributed.send)
        self.module.distributed.recv = func_torch_distributed_wrapper(self.module.distributed.recv)
        self.module.distributed.barrier = func_torch_distributed_wrapper(self.module.distributed.barrier)

    def run_torch_reproduce_hook(self):
        # TODO: add the logic for the torch reproduce hook
        Logger.info("[PPROBE] torch reproduce hook executed")

    def run_torch_catch_step_hook(self):
        ###################################################
        # torch.autograd.backward / torch.Tensor.backward
        ###################################################
        Logger.info("[PPROBE] torch catch step hook executed")
        self.module.autograd.backward = func_torch_step_count_wrapper(self.module.autograd.backward)

    def run_torch_perf_hook(self):
        Logger.info("[PPROBE] torch perf hook executed")

        ###################################################
        ## torch.Tensor.to part
        ###################################################
        self.module.Tensor.to = func_torch_device_conversion_wrapper(self.module.Tensor.to)
        self.module.Tensor.cpu = func_torch_device_conversion_wrapper(self.module.Tensor.cpu)
        self.module.Tensor.cuda = func_torch_device_conversion_wrapper(self.module.Tensor.cuda)

    def print_warning(self):
        if not getattr(self, 'warning_printed', False):
            print("[PPROBE] Please set the environment variable PPROBE_ENABLE=1 to use pprobe.")
            self.warning_printed = True


class MetaPathFinder:

    def find_module(self, module_fullname, path=None):
        # Logger.info('find_module {}'.format(module_fullname))
        if module_fullname in _hook_modules:
            return MetaPathLoader()


class MetaPathLoader:

    def load_module(self, module_fullname):
        # Logger.info('load_module {}'.format(module_fullname))

        # sys.modules caches every module that has already been imported
        if module_fullname in sys.modules:
            return sys.modules[module_fullname]

        ##################################################
        # Remove our custom finder from sys.meta_path first,
        # so that the import_module call below does not hit
        # this finder again and recurse endlessly.
        ##################################################
        finder = sys.meta_path.pop(0)
        module = importlib.import_module(module_fullname)

        # Logger.info(f"META-PATH-LOADER --> MODULE {module}")
        pprobe = PProbeSetup(module, module_fullname)

        sys.meta_path.insert(0, finder)

        return pprobe.module


sys.meta_path.insert(0, MetaPathFinder())
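For orientation, here is a minimal usage sketch of how this import hook is meant to fire. It assumes the finder is registered at interpreter startup (for example via a .pth file) and that toggles map to same-named environment variables, as the print_warning message suggests; both details are assumptions, not confirmed by this diff.

# Hypothetical usage sketch: enable toggles, then simply import torch.
import os

os.environ["PPROBE_ENABLE"] = "1"      # master switch (assumed env-var name)
os.environ["TORCH_CATCH_STEP"] = "1"   # backward() counting (assumed env-var name)

import torch  # intercepted by MetaPathFinder, returned already patched

x = torch.randn(4, requires_grad=True)
(x * x).sum().backward()  # routed through func_torch_step_count_wrapper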
Empty file.
@@ -0,0 +1,25 @@
import time
import functools
from pprobe.utils.logging import Logger

func_counts = 0


def func_torch_step_count_wrapper(func):
    ###################################################
    # torch.autograd.backward / torch.Tensor.backward
    ###################################################
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global func_counts

        if callable(func):
            start = time.time()
            result = func(*args, **kwargs)
            end = time.time()
            func_counts += 1
            Logger.info(f"[PPROBE] func_name {func} --> counts {func_counts}, elapsed {end - start:.6f}s")
            return result
        else:
            # handle the case where func is not callable
            Logger.warn(f"func:{func} is not callable")
    return wrapper
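Because the wrapper itself has no torch dependency, a standalone sanity check is possible (hypothetical usage; note that func_counts is module-global, so all wrapped functions share one counter):

from pprobe.bootstrap.hooks.pytorch_catch import func_torch_step_count_wrapper

@func_torch_step_count_wrapper
def fake_backward():
    pass  # stands in for torch.autograd.backward

fake_backward()  # logs: func_name ... --> counts 1
fake_backward()  # logs: func_name ... --> counts 2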
Empty file.
@@ -0,0 +1,19 @@
import functools
from pprobe.utils.logging import Logger


def func_torch_distributed_wrapper(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if callable(func):
            result = func(*args, **kwargs)
            # TODO: Refine the handling of each function.
            Logger.info(f"[PPROBE] torch.distributed.{func.__qualname__}")
            return result
        else:
            # handle the case where func is not callable
            Logger.warn(f"[PPROBE] func:{func} is not callable")
    return wrapper
File renamed without changes.
Empty file.
@@ -0,0 +1,24 @@
import functools
from pprobe.utils.logging import Logger


def func_torch_device_conversion_wrapper(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if callable(func):
            tensor_ret = func(*args, **kwargs)

            if func.__name__ == "to":
                Logger.warn(f"[PPROBE] found device conversion call {func}, the tensor was converted to {tensor_ret.device}")
            elif func.__name__ == "cpu":
                Logger.warn(f"[PPROBE] found device conversion call {func}, the tensor is on CPU")
            elif func.__name__ == "cuda":
                Logger.info(f"[PPROBE] found device conversion call {func}, the tensor is on {tensor_ret.device}")

            return tensor_ret
        else:
            # handle the case where func is not callable
            Logger.warn(f"func:{func} is not callable")
    return wrapper
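A minimal sketch of how this wrapper gets applied, mirroring run_torch_perf_hook in the bootstrap module (assumes torch is installed; patching by assignment is exactly what that hook does):

import torch
from pprobe.bootstrap.hooks.pytorch_perf import func_torch_device_conversion_wrapper

# Patch Tensor.to the same way run_torch_perf_hook does.
torch.Tensor.to = func_torch_device_conversion_wrapper(torch.Tensor.to)

t = torch.ones(2)
t = t.to("cpu")  # wrapper logs the conversion and the resulting device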
This file was deleted.