Skip to content

Commit

Permalink
adding benchmarking (#3)
Browse files Browse the repository at this point in the history
* adding benchmarking


Co-authored-by: amainz <[email protected]>
Co-authored-by: adamomainz <[email protected]>
  • Loading branch information
3 people authored Aug 15, 2024
1 parent ac74d2f commit 792e99f
Show file tree
Hide file tree
Showing 26 changed files with 2,612 additions and 622 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
__pycache__/
*.py[cod]
.pytest_cache
**/.cache
**/meta-llama/**/*
2 changes: 2 additions & 0 deletions benchmarking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .profiler import Profiler
from .benchmark_utils import compare_benchmarks
28 changes: 28 additions & 0 deletions benchmarking/benchmark_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Dict
import pandas as pd


def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
series_dict = {k: pd.Series(v.values()) for k, v in benchmarks.items()}
series_dict["kernel_path"] = pd.Series(
benchmarks[list(benchmarks.keys())[0]].keys()
)
series_dict["kernel"] = pd.Series(
[k.split(".")[-1] for k in series_dict["kernel_path"]]
)
df = pd.DataFrame()

for k, v in series_dict.items():
df[k] = v
columns = [c for c in df.columns if not "kernel" in c]
for i in range(len(columns)):
for j in range(i + 1, len(columns)):
# calculate the difference between the two columns
diff_col_name = f"{columns[i]}-{columns[j]}"
df[diff_col_name] = df[columns[i]] - df[columns[j]]
df.sort_values(by="kernel_path", inplace=True)
columns = [c for c in df.columns if not "kernel" in c]
columns = ["kernel", "kernel_path"] + columns
df = df[columns]
df.set_index("kernel", inplace=True)
return df
81 changes: 81 additions & 0 deletions benchmarking/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
import contextlib
import time
from collections import defaultdict


class Profiler:
_instance = None

def __new__(cls, should_profile: bool = False, benchmark: bool = False):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.profiler = (
torch.profiler.profile(
record_shapes=True,
with_flops=True,
profile_memory=True,
with_stack=True,
with_modules=True,
)
if should_profile
else None
)
cls._instance.benchmark = benchmark
cls._instance.benchmark_vals = defaultdict(list)
cls._instance.function_stack = []

return cls._instance

@classmethod
def reset(cls):
cls._instance = None

@classmethod
def profiling_decorator(
cls,
record_name: str = None,
skip_profiling: bool = False,
skip_benchmark: bool = False,
):
def inner(func):
def wrapper(*args, **kwargs):
if not cls._instance or (skip_profiling and skip_benchmark):
return func(*args, **kwargs)
cls._instance.function_stack.append(record_name or func.__name__)
name = ".".join(cls._instance.function_stack)
if cls._instance.profiler and not skip_profiling:
cls._instance.profiler.start()
start_time = time.perf_counter()

with torch.profiler.record_function(name):
result = func(*args, **kwargs)

end_time = time.perf_counter()
if cls._instance.benchmark and not skip_benchmark:
cls._instance.benchmark_vals[name].append(end_time - start_time)
if cls._instance.profiler and not skip_profiling:
cls._instance.profiler.stop()
cls._instance.function_stack.pop()
return result

return wrapper

return inner

@classmethod
def step(cls):
if cls._instance and cls._instance.profiler:
cls._instance.profiler.step()

@classmethod
def get_benchmark_vals(cls):
if cls._instance and cls._instance.benchmark:
return {k: sum(v) / len(v) for k, v in cls._instance.benchmark_vals.items()}
return None

@classmethod
def get_profiling_data(cls):
if cls._instance and cls._instance.profiler:
return self.profiler.key_averages()
return None
10 changes: 9 additions & 1 deletion kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,12 @@
from .flash_attention import attention
from .matmul import _matmul, get_higher_dtype, matmul

__all__ = ["blocksparse", "_cross_entropy", "cross_entropy", "_matmul", "matmul", "attention", "get_higher_dtype"]
__all__ = [
"blocksparse",
"_cross_entropy",
"cross_entropy",
"_matmul",
"matmul",
"attention",
"get_higher_dtype",
]
Loading

0 comments on commit 792e99f

Please sign in to comment.