Skip to content

Commit

Permalink
[Tests] Fix op name (#219)
Browse files Browse the repository at this point in the history
Fix op name in tests and benchmark.

---------

Co-authored-by: zhengyang <[email protected]>
  • Loading branch information
zhzhcookie and zhengyang authored Sep 19, 2024
1 parent cb2dc9e commit 204f3d4
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion OperatorList.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
- sum
- var_mean
- vector_norm
- cross_entropy_loss
- CrossEntropyLoss
- group_norm
- log_softmax
- native_group_norm
Expand Down
2 changes: 1 addition & 1 deletion benchmark/test_pointwise_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_perf_div():

def test_perf_floordiv_int():
bench = Benchmark(
op_name="floor_div",
op_name="floor_divide",
torch_op=torch.floor_divide,
arg_func=binary_int_args,
dtypes=INT_DTYPES,
Expand Down
2 changes: 1 addition & 1 deletion benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def cross_entropy_loss_args(dtype, batch, size):
return inp, target

bench = Benchmark(
op_name="cross_entropy_loss",
op_name="CrossEntropyLoss",
torch_op=torch.nn.CrossEntropyLoss(),
arg_func=cross_entropy_loss_args,
dtypes=FLOAT_DTYPES,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_binary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_accuracy_div_scalar_tensor(shape, scalar, dtype):
gems_assert_close(res_out, ref_out, dtype, equal_nan=True)


@pytest.mark.div
@pytest.mark.trunc_divide
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float32])
# Note : tl.math.div_rz only support float32, cast will cause diff
Expand All @@ -324,7 +324,7 @@ def test_accuracy_trunc_div(shape, dtype):


# TODO: failed at large size, eg. (65536 * 2048,)
@pytest.mark.div
@pytest.mark.floor_divide
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float32])
def test_accuracy_floor_div_float(shape, dtype):
Expand Down Expand Up @@ -1044,6 +1044,7 @@ def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan):
REPEAT_INTERLEAVE_DIM = [-1, 0, None]


@pytest.mark.repeat_interleave
@pytest.mark.parametrize("shape", REPEAT_INTERLEAVE_SHAPES + [(1,)])
@pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand All @@ -1058,6 +1059,7 @@ def test_accuracy_repeat_interleave_self_int(shape, dim, dtype):
gems_assert_equal(res_out, ref_out)


@pytest.mark.repeat_interleave
@pytest.mark.parametrize("shape", REPEAT_INTERLEAVE_SHAPES)
@pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_norm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


@pytest.mark.group_norm
@pytest.mark.native_group_norm
@pytest.mark.parametrize(
"N, C, H, W, num_groups",
[
Expand Down Expand Up @@ -75,6 +76,7 @@ def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype):

# TODO: failed at (1, 2) (2~32, 40499) (200, 2~64) (200~4096, 40999)
@pytest.mark.layer_norm
@pytest.mark.native_layer_norm
@pytest.mark.parametrize(
"shape", [(1, 40999)] if QUICK_MODE else [(1, 40999), (4096, 256)]
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
CROSS_ENTROPY_LOSS_REDUCTION = ["sum"] if QUICK_MODE else ["mean", "none", "sum"]


@pytest.mark.amx
@pytest.mark.amax
@pytest.mark.parametrize("keepdim, dim, shape", KEEPDIM_DIMS_SHAPE)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_amax(shape, dim, keepdim, dtype):
Expand Down
1 change: 1 addition & 0 deletions tests/test_special_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

# TODO: sometimes failed at (8192,), 0.6, bfloat16
@pytest.mark.dropout
@pytest.mark.native_dropout
@pytest.mark.parametrize("shape", SPECIAL_SHAPES)
@pytest.mark.parametrize("p", [0.3, 0.6, 0.9])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand Down

0 comments on commit 204f3d4

Please sign in to comment.