Commit: refine && fix

clemente0731 committed Aug 8, 2024
1 parent 66e44d8 commit e99b723

Showing 4 changed files with 18 additions and 5 deletions.
10 changes: 7 additions & 3 deletions pprobe/bootstrap/hooks/pytorch_catch.py
@@ -4,10 +4,13 @@
from typing import Any, Callable

func_counts = 0


def func_torch_step_count_wrapper(func):
    """
    torch.autograd.backward / torch.Tensor.backward
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global func_counts
@@ -30,8 +33,9 @@ def dataloader_next_method_wrapper(original_next: Callable) -> Callable:
"""
Decorator function to wrap the original __next__ method and add additional debug information.
"""

def wrapper(self) -> Any:
Logger.info("[PPROBE] Iteration count ===>:", getattr(self, '_num_yielded', 'N/A'))
Logger.info(f"[PPROBE] Iteration count ===>:{self._num_yielded}")
return original_next(self)
return wrapper

return wrapper
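The two hooks above only define wrappers; the diff does not show where pprobe attaches them. A minimal, hypothetical sketch of how they could be installed by monkey-patching follows: the docstring names torch.autograd.backward / torch.Tensor.backward as the step-count target, and the __next__ hook reads the _num_yielded counter kept by PyTorch's _BaseDataLoaderIter, but the wiring code itself is an assumption, not pprobe's confirmed bootstrap logic.

# Hypothetical wiring sketch -- not pprobe's actual bootstrap code.
import torch
import torch.utils.data.dataloader as dl

from pprobe.bootstrap.hooks.pytorch_catch import (
    dataloader_next_method_wrapper,
    func_torch_step_count_wrapper,
)

# Count backward() calls as a proxy for training steps
# (torch.Tensor.backward delegates to torch.autograd.backward).
torch.autograd.backward = func_torch_step_count_wrapper(torch.autograd.backward)

# Log _num_yielded on every batch fetched from a DataLoader iterator.
dl._BaseDataLoaderIter.__next__ = dataloader_next_method_wrapper(
    dl._BaseDataLoaderIter.__next__
)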
2 changes: 1 addition & 1 deletion pprobe/bootstrap/hooks/pytorch_dist.py
@@ -15,7 +15,7 @@ def func_torch_distributed_wrapper(func):
    def wrapper(*args, **kwargs):
        if callable(func):
            result = func(*args, **kwargs)
-            if func.__module__ == 'torch.distributed' and func.__name__ == 'all_gather':
+            if func.__module__ == "torch.distributed" and func.__name__ == "all_gather":
                print("xxxxxxx", func)
                if isinstance(args, tuple):
                    Logger.info(f"[PPROBE] torch.distributed.{func.__qualname__}")
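Only the quote style of the wrapper body changes here. As with the catch hooks, the install site is not part of this commit; a hedged sketch of applying the wrapper to a collective, for illustration only:

# Hypothetical application of the wrapper above -- the actual install site is not in this diff.
import torch.distributed as dist

from pprobe.bootstrap.hooks.pytorch_dist import func_torch_distributed_wrapper

# Route all_gather calls through the logging wrapper.
dist.all_gather = func_torch_distributed_wrapper(dist.all_gather)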
7 changes: 7 additions & 0 deletions pprobe/tests/xtest_torchvision_model.py
@@ -477,6 +477,7 @@ def validate(val_loader, model, criterion, args):
    def run_validate(loader, base_progress=0):
        with torch.no_grad():
            end = time.time()
+            ST = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i
                if args.gpu is not None and torch.cuda.is_available():
@@ -504,6 +505,12 @@ def run_validate(loader, base_progress=0):
                if i % args.print_freq == 0:
                    progress.display(i + 1)

+                if i >= 2:
+                    print(f"MODEL EVAL FINISH {args.arch}: time duration:{time.time()-ST}")
+                    import sys
+
+                    sys.exit(0)
+
    batch_time = AverageMeter("Time", ":6.3f", Summary.NONE)
    losses = AverageMeter("Loss", ":.4e", Summary.NONE)
    top1 = AverageMeter("Acc@1", ":6.2f", Summary.AVERAGE)
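The added block turns the evaluation pass into a quick smoke test: after three batches it prints the elapsed time and terminates the process. A sketch of an alternative that returns normally instead of calling sys.exit(0); names mirror the test script and the sketch is illustrative, not part of this commit:

# Alternative sketch: cap the loop with itertools.islice instead of sys.exit(0),
# so the validation helper returns normally. Illustration only.
import itertools
import time

def run_validate_smoke(loader, arch, max_batches=3):
    start = time.time()
    for i, (images, target) in enumerate(itertools.islice(loader, max_batches)):
        pass  # forward pass and metric updates would go here
    print(f"MODEL EVAL FINISH {arch}: time duration:{time.time() - start}")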
4 changes: 3 additions & 1 deletion script/test.sh
@@ -11,8 +11,10 @@ cd pprobe/tests

PPROBE --enable PPROBE_ENABLE
PPROBE --enable TORCH_DUMP_MODULE
PPROBE --enable TORCH_CATCH_STEP

PPROBE_ENABLE=1 python xtest_torchvision_model.py -a resnet50 --epochs 1 -b 12 -p 1 --seed 42 --dummy
PPROBE_ENABLE=1 python xtest_torchvision_model.py -a resnet50 --epochs 1 -b 12 -p 1 --seed 42 --dummy
PPROBE_ENABLE=1 python xtest_torchvision_model.py -a resnet50 --epochs 1 -b 12 -p 1 --seed 42 --dummy --evaluate
PPROBE_ENABLE=1 python xtest_device_conversion_detection.py

PPROBE --reset
