[test] move specific arg-gen functions into corresponding tests (#60)
StrongSpoon authored Jun 12, 2024
1 parent c84b6a5 commit 19fb335
Showing 5 changed files with 247 additions and 105 deletions.
113 changes: 9 additions & 104 deletions benchmark/performance_utils.py
@@ -17,6 +17,10 @@ def __init__(self, op_name, torch_op, arg_func, dtype, batch, sizes):
         self.dtype = dtype
         self.batch = batch
         self.sizes = sizes
+        self.gems_op = None
+
+    def set_gems(self, gems_op):
+        self.gems_op = gems_op

     def profile(self, op, *args):
         if CPU_MODE:
@@ -43,8 +47,11 @@ def run(self):
         for size in self.sizes:
             args = self.arg_func(self.dtype, self.batch, size)
             torch_perf = self.profile(self.torch_op, *args)
-            with flag_gems.use_gems():
-                gems_perf = self.profile(self.torch_op, *args)
+            if self.gems_op:
+                gems_perf = self.profile(self.gems_op, *args)
+            else:
+                with flag_gems.use_gems():
+                    gems_perf = self.profile(self.torch_op, *args)
             print(f"{size: <10}{torch_perf: >20.6}{gems_perf: >20.6}")


@@ -92,105 +99,3 @@ def ternary_args(dtype, batch, size):
     inp2 = torch.randn([batch, size], dtype=dtype, device="cuda")
     inp3 = torch.randn([batch, size], dtype=dtype, device="cuda")
     return inp1, inp2, inp3
-
-
-def cross_entropy_loss_args(dtype, batch, size):
-    inp = torch.randn([batch, size], dtype=dtype, device="cuda")
-    target = torch.randint(
-        0,
-        size,
-        [
-            batch,
-        ],
-        device="cuda",
-    )
-    return inp, target
-
-
-def cumsum_args(dtype, batch, size):
-    inp = torch.randn([batch, size], dtype=dtype, device="cuda")
-    return inp, 1
-
-
-def group_norm_args(dtype, batch, size):
-    C = 16
-    G = 16
-    inp = torch.randn([batch, C, size], dtype=dtype, device="cuda")
-    weight = torch.randn(
-        [
-            C,
-        ],
-        dtype=dtype,
-        device="cuda",
-    )
-    bias = torch.randn(
-        [
-            C,
-        ],
-        dtype=dtype,
-        device="cuda",
-    )
-    return inp, G, weight, bias
-
-
-def layer_norm_args(dtype, batch, size):
-    inp = torch.randn([batch, size], dtype=dtype, device="cuda")
-    weight = torch.randn(
-        [
-            size,
-        ],
-        dtype=dtype,
-        device="cuda",
-    )
-    bias = torch.randn(
-        [
-            size,
-        ],
-        dtype=dtype,
-        device="cuda",
-    )
-    return (
-        inp,
-        [
-            size,
-        ],
-        weight,
-        bias,
-    )
-
-
-def addmm_args(dtype, batch, size):
-    bias = torch.randn(
-        [
-            size,
-        ],
-        dtype=dtype,
-        device="cuda",
-    )
-    inp1 = torch.randn([size, size], dtype=dtype, device="cuda")
-    inp2 = torch.randn([size, size], dtype=dtype, device="cuda")
-    return bias, inp1, inp2
-
-
-def bmm_args(dtype, batch, size):
-    inp1 = torch.randn([batch, size, size], dtype=dtype, device="cuda")
-    inp2 = torch.randn([batch, size, size], dtype=dtype, device="cuda")
-    return inp1, inp2
-
-
-def mm_args(dtype, batch, size):
-    inp1 = torch.randn([size, size], dtype=dtype, device="cuda")
-    inp2 = torch.randn([size, size], dtype=dtype, device="cuda")
-    return inp1, inp2
-
-
-def mv_args(dtype, batch, size):
-    inp1 = torch.randn([size, size], dtype=dtype, device="cuda")
-    inp2 = torch.randn([size], dtype=dtype, device="cuda")
-    return inp1, inp2
-
-
-def outer_args(dtype, batch, size):
-    inp1 = torch.randn([size], dtype=dtype, device="cuda")
-    inp2 = torch.randn([size], dtype=dtype, device="cuda")
-    return inp1, inp2
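Note: with the set_gems hook and the new dispatch in run() above, a Benchmark profiles the registered FlagGems function directly when one is set, and otherwise falls back to patching the torch op via flag_gems.use_gems(). A minimal usage sketch of both paths (the import path is illustrative; Benchmark, binary_args, POINTWISE_BATCH, and SIZES are assumed to be exported by performance_utils, as the star imports in the test files suggest):

import torch
import flag_gems
from benchmark.performance_utils import (
    Benchmark,
    binary_args,
    POINTWISE_BATCH,
    SIZES,
)

# Path 1: no gems_op registered, so run() wraps the torch op in flag_gems.use_gems().
Benchmark("add", torch.add, binary_args, torch.float16, POINTWISE_BATCH, SIZES).run()

# Path 2: a fused FlagGems op registered via set_gems() is profiled directly.
bench = Benchmark(
    op_name="gelu_and_mul",
    torch_op=lambda x, y: torch.mul(torch.nn.functional.gelu(x), y),
    arg_func=binary_args,
    dtype=torch.float16,
    batch=POINTWISE_BATCH,
    sizes=SIZES,
)
bench.set_gems(flag_gems.gelu_and_mul)
bench.run()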
37 changes: 37 additions & 0 deletions benchmark/test_blas_perf.py
@@ -6,6 +6,19 @@

@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_addmm(dtype):

    def addmm_args(dtype, batch, size):
        bias = torch.randn(
            [
                size,
            ],
            dtype=dtype,
            device="cuda",
        )
        inp1 = torch.randn([size, size], dtype=dtype, device="cuda")
        inp2 = torch.randn([size, size], dtype=dtype, device="cuda")
        return bias, inp1, inp2

    bench = Benchmark(
        op_name="addmm",
        torch_op=torch.addmm,
@@ -19,6 +32,12 @@ def test_perf_addmm(dtype):

@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_bmm(dtype):

    def bmm_args(dtype, batch, size):
        inp1 = torch.randn([batch, size, size], dtype=dtype, device="cuda")
        inp2 = torch.randn([batch, size, size], dtype=dtype, device="cuda")
        return inp1, inp2

    bench = Benchmark(
        op_name="bmm",
        torch_op=torch.bmm,
@@ -32,6 +51,12 @@ def test_perf_bmm(dtype):

@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_mm(dtype):

    def mm_args(dtype, batch, size):
        inp1 = torch.randn([size, size], dtype=dtype, device="cuda")
        inp2 = torch.randn([size, size], dtype=dtype, device="cuda")
        return inp1, inp2

    bench = Benchmark(
        op_name="mm",
        torch_op=torch.mm,
@@ -45,6 +70,12 @@ def test_perf_mm(dtype):

@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_mv(dtype):

    def mv_args(dtype, batch, size):
        inp1 = torch.randn([size, size], dtype=dtype, device="cuda")
        inp2 = torch.randn([size], dtype=dtype, device="cuda")
        return inp1, inp2

    bench = Benchmark(
        op_name="mv",
        torch_op=torch.mv,
@@ -58,6 +89,12 @@ def test_perf_mv(dtype):

@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_outer(dtype):

    def outer_args(dtype, batch, size):
        inp1 = torch.randn([size], dtype=dtype, device="cuda")
        inp2 = torch.randn([size], dtype=dtype, device="cuda")
        return inp1, inp2

    bench = Benchmark(
        op_name="outer",
        torch_op=torch.outer,
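The Benchmark(...) calls above are truncated in this view. For reference, a complete test in the new style pairs the locally defined arg function with the benchmark roughly as follows (a sketch; the batch value shown here is illustrative, since the actual constant is cut off above, and mm_args ignores it anyway):

@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_mm(dtype):

    def mm_args(dtype, batch, size):
        inp1 = torch.randn([size, size], dtype=dtype, device="cuda")
        inp2 = torch.randn([size, size], dtype=dtype, device="cuda")
        return inp1, inp2

    bench = Benchmark(
        op_name="mm",
        torch_op=torch.mm,
        arg_func=mm_args,
        dtype=dtype,
        batch=POINTWISE_BATCH,  # illustrative; the real constant is not visible in this diff
        sizes=SIZES,
    )
    bench.run()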
134 changes: 134 additions & 0 deletions benchmark/test_fused_perf.py
@@ -0,0 +1,134 @@
import torch
import pytest
import flag_gems
from .performance_utils import *


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_gelu_and_mul(dtype):

    def torch_op(x, y):
        return torch.mul(torch.nn.functional.gelu(x), y)

    gems_op = flag_gems.gelu_and_mul

    bench = Benchmark(
        op_name="gelu_and_mul",
        torch_op=torch_op,
        arg_func=binary_args,
        dtype=dtype,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.set_gems(gems_op)
    bench.run()


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_silu_and_mul(dtype):

    def torch_op(x, y):
        return torch.mul(torch.nn.functional.silu(x), y)

    gems_op = flag_gems.silu_and_mul

    bench = Benchmark(
        op_name="silu_and_mul",
        torch_op=torch_op,
        arg_func=binary_args,
        dtype=dtype,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.set_gems(gems_op)
    bench.run()


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_skip_layernorm(dtype):

    def skip_layernorm_args(dtype, batch, size):
        inp = torch.randn([batch, size], dtype=dtype, device="cuda")
        residual = torch.randn([batch, size], dtype=dtype, device="cuda")
        weight = torch.randn(
            [
                size,
            ],
            dtype=dtype,
            device="cuda",
        )
        bias = torch.randn(
            [
                size,
            ],
            dtype=dtype,
            device="cuda",
        )
        return (
            inp,
            residual,
            [
                size,
            ],
            weight,
            bias,
        )

    def torch_op(inp, residual, layer_shape, weight, bias):
        return torch.layer_norm(inp + residual, layer_shape, weight, bias)

    gems_op = flag_gems.skip_layer_norm

    bench = Benchmark(
        op_name="skip_layernorm",
        torch_op=torch_op,
        arg_func=skip_layernorm_args,
        dtype=dtype,
        batch=REDUCTION_BATCH,
        sizes=SIZES,
    )
    bench.set_gems(gems_op)
    bench.run()


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_skip_rmsnorm(dtype):

    def skip_rmsnorm_args(dtype, batch, size):
        inp = torch.randn([batch, size], dtype=dtype, device="cuda")
        residual = torch.randn([batch, size], dtype=dtype, device="cuda")
        weight = torch.randn(
            [
                size,
            ],
            dtype=dtype,
            device="cuda",
        )
        return (
            inp,
            residual,
            [
                size,
            ],
            weight,
            1e-5,
        )

    def torch_op(x, residual, layer_shape, weight, eps):
        x = x + residual
        variance = x.pow(2).mean(-1, keepdim=True)
        hidden_states = x * torch.rsqrt(variance + eps)
        return weight * hidden_states

    gems_op = flag_gems.skip_rms_norm

    bench = Benchmark(
        op_name="skip_rmsnorm",
        torch_op=torch_op,
        arg_func=skip_rmsnorm_args,
        dtype=dtype,
        batch=REDUCTION_BATCH,
        sizes=SIZES,
    )
    bench.set_gems(gems_op)
    bench.run()
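The fused benchmarks above only measure latency. A quick correctness spot-check of the fused skip-RMSNorm against the torch reference defined in test_perf_skip_rmsnorm could look like the sketch below (it reuses the argument layout from skip_rmsnorm_args; shapes and tolerances are illustrative):

import torch
import flag_gems

batch, size, eps = 1024, 1024, 1e-5
x = torch.randn([batch, size], dtype=torch.float16, device="cuda")
residual = torch.randn([batch, size], dtype=torch.float16, device="cuda")
weight = torch.randn([size], dtype=torch.float16, device="cuda")

# torch reference: RMS-normalize (x + residual), then scale by weight
hidden = x + residual
ref = weight * (hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + eps))

# fused FlagGems op, called with the same argument layout the benchmark passes via *args
out = flag_gems.skip_rms_norm(x, residual, [size], weight, eps)

torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)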