-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding benchmarking Co-authored-by: amainz <[email protected]> Co-authored-by: adamomainz <[email protected]>
- Loading branch information
1 parent
ac74d2f
commit 792e99f
Showing
26 changed files
with
2,612 additions
and
622 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,5 @@ | |
__pycache__/ | ||
*.py[cod] | ||
.pytest_cache | ||
**/.cache | ||
**/meta-llama/**/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .profiler import Profiler | ||
from .benchmark_utils import compare_benchmarks |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Any, Dict | ||
import pandas as pd | ||
|
||
|
||
def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: | ||
series_dict = {k: pd.Series(v.values()) for k, v in benchmarks.items()} | ||
series_dict["kernel_path"] = pd.Series( | ||
benchmarks[list(benchmarks.keys())[0]].keys() | ||
) | ||
series_dict["kernel"] = pd.Series( | ||
[k.split(".")[-1] for k in series_dict["kernel_path"]] | ||
) | ||
df = pd.DataFrame() | ||
|
||
for k, v in series_dict.items(): | ||
df[k] = v | ||
columns = [c for c in df.columns if not "kernel" in c] | ||
for i in range(len(columns)): | ||
for j in range(i + 1, len(columns)): | ||
# calculate the difference between the two columns | ||
diff_col_name = f"{columns[i]}-{columns[j]}" | ||
df[diff_col_name] = df[columns[i]] - df[columns[j]] | ||
df.sort_values(by="kernel_path", inplace=True) | ||
columns = [c for c in df.columns if not "kernel" in c] | ||
columns = ["kernel", "kernel_path"] + columns | ||
df = df[columns] | ||
df.set_index("kernel", inplace=True) | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import torch | ||
import contextlib | ||
import time | ||
from collections import defaultdict | ||
|
||
|
||
class Profiler: | ||
_instance = None | ||
|
||
def __new__(cls, should_profile: bool = False, benchmark: bool = False): | ||
if cls._instance is None: | ||
cls._instance = super().__new__(cls) | ||
cls._instance.profiler = ( | ||
torch.profiler.profile( | ||
record_shapes=True, | ||
with_flops=True, | ||
profile_memory=True, | ||
with_stack=True, | ||
with_modules=True, | ||
) | ||
if should_profile | ||
else None | ||
) | ||
cls._instance.benchmark = benchmark | ||
cls._instance.benchmark_vals = defaultdict(list) | ||
cls._instance.function_stack = [] | ||
|
||
return cls._instance | ||
|
||
@classmethod | ||
def reset(cls): | ||
cls._instance = None | ||
|
||
@classmethod | ||
def profiling_decorator( | ||
cls, | ||
record_name: str = None, | ||
skip_profiling: bool = False, | ||
skip_benchmark: bool = False, | ||
): | ||
def inner(func): | ||
def wrapper(*args, **kwargs): | ||
if not cls._instance or (skip_profiling and skip_benchmark): | ||
return func(*args, **kwargs) | ||
cls._instance.function_stack.append(record_name or func.__name__) | ||
name = ".".join(cls._instance.function_stack) | ||
if cls._instance.profiler and not skip_profiling: | ||
cls._instance.profiler.start() | ||
start_time = time.perf_counter() | ||
|
||
with torch.profiler.record_function(name): | ||
result = func(*args, **kwargs) | ||
|
||
end_time = time.perf_counter() | ||
if cls._instance.benchmark and not skip_benchmark: | ||
cls._instance.benchmark_vals[name].append(end_time - start_time) | ||
if cls._instance.profiler and not skip_profiling: | ||
cls._instance.profiler.stop() | ||
cls._instance.function_stack.pop() | ||
return result | ||
|
||
return wrapper | ||
|
||
return inner | ||
|
||
@classmethod | ||
def step(cls): | ||
if cls._instance and cls._instance.profiler: | ||
cls._instance.profiler.step() | ||
|
||
@classmethod | ||
def get_benchmark_vals(cls): | ||
if cls._instance and cls._instance.benchmark: | ||
return {k: sum(v) / len(v) for k, v in cls._instance.benchmark_vals.items()} | ||
return None | ||
|
||
@classmethod | ||
def get_profiling_data(cls): | ||
if cls._instance and cls._instance.profiler: | ||
return self.profiler.key_averages() | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.