-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathbenchmark.py
147 lines (134 loc) · 5.83 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Adapted from https://github.com/HazyResearch/hippo/blob/datasets/benchmark/utils.py
""" Utilities for benchmarking, profiling, and measuring peak memory of PyTorch functions. """
import torch
import torch.utils.benchmark as benchmark
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
                      amp_dtype=torch.float16, **kwinputs):
    """Time the forward pass of ``fn`` using ``torch.utils.benchmark``.

    Args:
        fn: Callable to time; its return value is discarded.
        *inputs: Positional arguments forwarded to ``fn``.
        repeats: Number of executions passed to ``Timer.timeit``.
        desc: Label printed before the measurement when ``verbose`` is set.
        verbose: If True, print a header line and the resulting measurement.
        amp: If True, ``fn`` runs under CUDA autocast with ``amp_dtype``.
        amp_dtype: Autocast dtype used when ``amp`` is enabled.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        A ``(timer, measurement)`` tuple.
    """
    if verbose:
        print(desc, '- Forward pass')

    def run_under_autocast(*args, **kwargs):
        # The autocast context is always entered; ``enabled=amp`` decides
        # whether it actually changes dtypes.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            fn(*args, **kwargs)

    timer_globals = dict(fn_amp=run_under_autocast, inputs=inputs, kwinputs=kwinputs)
    timer = benchmark.Timer(
        stmt='fn_amp(*inputs, **kwinputs)',
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """Time the backward pass of ``fn`` using ``torch.utils.benchmark``.

    ``fn`` is run once, untimed, to build the autograd graph. Only
    ``y.backward(grad, retain_graph=True)`` is timed, so every repeat reuses
    that graph and gradients accumulate into the inputs' ``.grad``.

    Args:
        fn: Callable to differentiate. If it returns a tuple, only element 0
            is backpropagated.
        *inputs: Positional arguments forwarded to ``fn``.
        grad: Upstream gradient. When None, a standard-normal tensor shaped
            like the output is drawn.
        repeats: Number of executions passed to ``Timer.timeit``.
        desc: Label printed before the measurement when ``verbose`` is set.
        verbose: If True, print a header line and the resulting measurement.
        amp: If True, the forward runs under CUDA autocast with ``amp_dtype``.
        amp_dtype: Autocast dtype used when ``amp`` is enabled.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        A ``(timer, measurement)`` tuple.

    Raises:
        RuntimeError: If ``grad`` does not have the output's shape.
    """
    if verbose:
        print(desc, '- Backward pass')
    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
        output = fn(*inputs, **kwinputs)
    if type(output) is tuple:
        output = output[0]

    if grad is None:
        grad = torch.randn_like(output)
    elif grad.shape != output.shape:
        raise RuntimeError('Grad shape does not match output shape')

    timer = benchmark.Timer(
        stmt='y.backward(grad, retain_graph=True)',
        globals=dict(y=output, grad=grad),
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                       amp_dtype=torch.float16, **kwinputs):
    """Time a forward pass followed by a backward pass of ``fn``.

    Each timed repetition rebuilds the autograd graph by calling ``fn`` and
    then backpropagates ``grad`` through it. The upstream gradient is resolved
    once, before timing, so the measurement does not include random-number
    generation.

    Args:
        fn: Callable to benchmark. If it returns a tuple, only element 0 is
            backpropagated.
        *inputs: Positional arguments forwarded to ``fn``.
        grad: Upstream gradient. When None, a standard-normal tensor shaped
            like the output is drawn once and reused for every repetition.
        repeats: Number of executions passed to ``Timer.timeit``.
        desc: Label printed before the measurement when ``verbose`` is set.
        verbose: If True, print a header line and the resulting measurement.
        amp: If True, the forward runs under CUDA autocast with ``amp_dtype``.
        amp_dtype: Autocast dtype used when ``amp`` is enabled.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        A ``(timer, measurement)`` tuple.

    Raises:
        RuntimeError: If ``grad`` does not have the output's shape.
    """
    if verbose:
        print(desc, '- Forward + Backward pass')

    def forward(*args, **kwargs):
        # Forward under (optionally enabled) autocast; keep only the first
        # element of a tuple output, as benchmark_backward does.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            out = fn(*args, **kwargs)
        return out[0] if type(out) is tuple else out

    # Resolve the upstream gradient once, outside the timed statement, so the
    # cost of torch.randn_like is not measured on every repeat (this mirrors
    # benchmark_backward). It costs one extra, untimed forward to learn the
    # output shape.
    y = forward(*inputs, **kwinputs)
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError('Grad shape does not match output shape')
    del y  # drop the setup graph so it is not kept alive while timing

    def f(grad, *inputs, **kwinputs):
        forward(*inputs, **kwinputs).backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt='f(grad, *inputs, **kwinputs)',
        globals={'f': f, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                  amp_dtype=torch.float16, **kwinputs):
    """Run the forward, backward, and forward+backward benchmarks for ``fn``, in that order.

    Returns:
        A 3-tuple ``(forward, backward, combined)``. Each element is the
        ``(timer, measurement)`` pair returned by ``benchmark_forward``,
        ``benchmark_backward``, and ``benchmark_combined`` respectively.
    """
    # Options shared by all three benchmarks; ``grad`` only applies to the two
    # that run a backward pass.
    common = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    forward_result = benchmark_forward(fn, *inputs, **common, **kwinputs)
    backward_result = benchmark_backward(fn, *inputs, grad=grad, **common, **kwinputs)
    combined_result = benchmark_combined(fn, *inputs, grad=grad, **common, **kwinputs)
    return forward_result, backward_result, combined_result
def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
                     amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
    """Profile one call of ``fn`` (optionally with its backward) using ``torch.profiler``.

    Runs 30 unprofiled warm-up iterations, then records a single iteration.

    Args:
        fn: Callable to profile. If it returns a tuple, only element 0 is
            backpropagated when ``backward`` is set.
        *inputs: Positional arguments forwarded to ``fn``. Tensor inputs have
            ``.grad`` cleared before each iteration when ``backward`` is set.
        trace_filename: If given, a Chrome trace is exported to this path.
        backward: If True, each iteration also backpropagates a fixed
            standard-normal gradient.
        amp: If True, the forward runs under CUDA autocast with ``amp_dtype``.
        amp_dtype: Autocast dtype used when ``amp`` is enabled.
        cpu: If True, CPU activity is profiled in addition to CUDA activity.
        verbose: If True, print the profiler's key-averages table (50 rows).
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        The ``torch.profiler.profile`` object for further inspection.
    """
    def _forward():
        # Forward under (optionally enabled) autocast; a tuple output is
        # reduced to its first element, as in benchmark_backward.
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
        return out[0] if type(out) is tuple else out

    # Fixed upstream gradient, drawn once and reused by every iteration.
    g = torch.randn_like(_forward()) if backward else None

    def _step():
        # One iteration: clear grads, forward, then (optionally) backward.
        if backward:
            # Clear .grad so backward writes fresh gradients instead of
            # accumulating into the previous iteration's.
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        out = _forward()
        # Backward should be done outside autocast
        if backward:
            out.backward(g)

    for _ in range(30):  # Warm up
        _step()

    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        with_stack=True,
    ) as prof:
        _step()
    if verbose:
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
    return prof
def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
    """Measure the peak CUDA memory allocated while running ``fn`` once.

    The allocator cache is emptied and peak statistics are reset before the
    call. The cache is emptied again afterwards, even if ``fn`` raises.

    Args:
        fn: Callable to run; its return value is discarded.
        *inputs: Positional arguments forwarded to ``fn``.
        desc: Label prefixed to the printed result.
        verbose: If True, print the peak memory.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        Peak allocated memory in GiB (2**30 bytes), as reported by
        ``torch.cuda.max_memory_allocated``.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    try:
        fn(*inputs, **kwinputs)
        torch.cuda.synchronize()
        # max_memory_allocated is in bytes. The previous divisor
        # (2**20 * 1000) was neither GB (1e9) nor GiB (2**30).
        mem = torch.cuda.max_memory_allocated() / 2 ** 30
    finally:
        torch.cuda.empty_cache()
    if verbose:
        print(f'{desc} max memory: {mem}GiB')
    return mem