Skip to content

Commit

Permalink
Reduce runtime dependency on torch
Browse files Browse the repository at this point in the history
  • Loading branch information
stephen-huan committed Dec 25, 2024
1 parent daa7eb0 commit e8d7023
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
10 changes: 8 additions & 2 deletions python/triton/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
end_event = di.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
if hasattr(cache, "zero_"):
cache.zero_()
else:
cache[:] = 0
fn()
end_event.record()
di.synchronize()
Expand Down Expand Up @@ -152,7 +155,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
for x in grad_to_none:
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
if hasattr(cache, "zero_"):
cache.zero_()
else:
cache[:] = 0
# record time of `fn`
start_event[i].record()
fn()
Expand Down
4 changes: 2 additions & 2 deletions third_party/cpu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,8 @@ def do_bench_cpu(*args, **kwargs):
return do_bench_cpu

def get_empty_cache_for_benchmark(self):
    """Return a throwaway buffer used to flush the CPU's last-level cache.

    The benchmark harness overwrites this buffer between timed runs so that
    each measurement starts from a cold LLC. Returns an uninitialized
    ``numpy`` array (contents are irrelevant; only its footprint matters),
    avoiding a hard runtime dependency on torch.
    """
    import numpy as np

    # A typical LLC size for high-end server CPUs is ~400MB, so allocate
    # a buffer comfortably larger than that.
    cache_size = 512 * 1024 * 1024
    # int32 elements are 4 bytes each, so divide the byte budget by 4.
    # // already yields an int, so no extra int() conversion is needed.
    return np.empty(cache_size // 4, dtype=np.int32)

0 comments on commit e8d7023

Please sign in to comment.