Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding benchmarking #3

Merged
merged 8 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
__pycache__/
*.py[cod]
.pytest_cache
**/.cache
**/meta-llama/**/*
2 changes: 2 additions & 0 deletions benchmarking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .profiler import Profiler
from .benchmark_utils import compare_benchmarks
28 changes: 28 additions & 0 deletions benchmarking/benchmark_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Dict
import pandas as pd


def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
series_dict = {k: pd.Series(v.values()) for k, v in benchmarks.items()}
series_dict["kernel_path"] = pd.Series(
benchmarks[list(benchmarks.keys())[0]].keys()
)
series_dict["kernel"] = pd.Series(
[k.split(".")[-1] for k in series_dict["kernel_path"]]
)
df = pd.DataFrame()

for k, v in series_dict.items():
df[k] = v
columns = [c for c in df.columns if not "kernel" in c]
for i in range(len(columns)):
for j in range(i + 1, len(columns)):
# calculate the difference between the two columns
diff_col_name = f"{columns[i]}-{columns[j]}"
df[diff_col_name] = df[columns[i]] - df[columns[j]]
df.sort_values(by="kernel_path", inplace=True)
columns = [c for c in df.columns if not "kernel" in c]
columns = ["kernel", "kernel_path"] + columns
df = df[columns]
df.set_index("kernel", inplace=True)
return df
81 changes: 81 additions & 0 deletions benchmarking/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
import contextlib
import time
from collections import defaultdict


class Profiler:
_instance = None

def __new__(cls, should_profile: bool = False, benchmark: bool = False):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.profiler = (
torch.profiler.profile(
record_shapes=True,
with_flops=True,
profile_memory=True,
with_stack=True,
with_modules=True,
)
if should_profile
else None
)
cls._instance.benchmark = benchmark
cls._instance.benchmark_vals = defaultdict(list)
cls._instance.function_stack = []

return cls._instance

@classmethod
def reset(cls):
cls._instance = None

@classmethod
def profiling_decorator(
cls,
record_name: str = None,
skip_profiling: bool = False,
skip_benchmark: bool = False,
):
def inner(func):
def wrapper(*args, **kwargs):
if not cls._instance or (skip_profiling and skip_benchmark):
return func(*args, **kwargs)
cls._instance.function_stack.append(record_name or func.__name__)
name = ".".join(cls._instance.function_stack)
if cls._instance.profiler and not skip_profiling:
cls._instance.profiler.start()
start_time = time.perf_counter()

with torch.profiler.record_function(name):
result = func(*args, **kwargs)

end_time = time.perf_counter()
if cls._instance.benchmark and not skip_benchmark:
cls._instance.benchmark_vals[name].append(end_time - start_time)
if cls._instance.profiler and not skip_profiling:
cls._instance.profiler.stop()
cls._instance.function_stack.pop()
return result

return wrapper

return inner

@classmethod
def step(cls):
if cls._instance and cls._instance.profiler:
cls._instance.profiler.step()

@classmethod
def get_benchmark_vals(cls):
if cls._instance and cls._instance.benchmark:
return {k: sum(v) / len(v) for k, v in cls._instance.benchmark_vals.items()}
return None

@classmethod
def get_profiling_data(cls):
if cls._instance and cls._instance.profiler:
return self.profiler.key_averages()
return None
10 changes: 9 additions & 1 deletion kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,12 @@
from .flash_attention import attention
from .matmul import _matmul, get_higher_dtype, matmul

__all__ = ["blocksparse", "_cross_entropy", "cross_entropy", "_matmul", "matmul", "attention", "get_higher_dtype"]
__all__ = [
"blocksparse",
"_cross_entropy",
"cross_entropy",
"_matmul",
"matmul",
"attention",
"get_higher_dtype",
]
Loading