Skip to content

Commit

Permalink
Update hook logic to include a hook for torch backward
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyu17 authored and jiangyu17 committed Apr 3, 2024
1 parent dc58fc4 commit 0e87a9a
Show file tree
Hide file tree
Showing 3 changed files with 718 additions and 23 deletions.
37 changes: 22 additions & 15 deletions bootstrap/_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@
import sys
import time

_hook_modules = {'hello_torch'}
sys.path.append('../utils') # 将模块所在目录添加到搜索路径中
from utils.xlog import XLogger

_hook_modules = {'torch'}
counts = 0

class MetaPathFinder:

def find_module(self, fullname, path=None):
print('find_module {}'.format(fullname))
# print('find_module {}'.format(fullname))
if fullname in _hook_modules:
return MetaPathLoader()


class MetaPathLoader:

def load_module(self, fullname):
print('load_module {}'.format(fullname))
# print('load_module {}'.format(fullname))
# ``sys.modules`` 中保存的是已经导入过的 module
if fullname in sys.modules:
return sys.modules[fullname]

# 先从 sys.meta_path 中删除自定义的 finder
# 防止下面执行 import_module 的时候再次触发此 finder
# 从而出现递归调用的问题
Expand All @@ -38,20 +37,28 @@ def load_module(self, fullname):


def module_hook(fullname, module):
if fullname == 'hello_torch':
# monkey-patch
# 这里把 torch.add替换成torch.sub
module.torch_add = func_wrapper(module.torch_sub)
# print("module =========:", module)
# print("fullname =========:", fullname)
if fullname == "torch":
# torch.Tensor.backward 和 torch.autograd.backward 是等价的
module.autograd.backward = func_count_wrapper(module.autograd.backward)


def func_wrapper(func):
def func_count_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
print("开始函数 == func")

global counts
XLogger.info(f"Function Start == {func}")
start = time.time()
result = func(*args, **kwargs)
end = time.time()
print("结束函数 == func")
print("花费时间 {}s".format(end - start))
counts += 1
XLogger.debug(f"steps {counts}: func args: {args}")
XLogger.debug(f"steps {counts}: func kwargs: {kwargs}")

XLogger.info("steps {}: {} takes {}s".format(counts, func, end - start))
XLogger.info(f"Function End == {func}")

return result
return wrapper
Loading

0 comments on commit 0e87a9a

Please sign in to comment.