From 204e8443d2d7fe24440e9b6bd447ca538e730da3 Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Thu, 27 Jun 2024 14:48:33 +0800 Subject: [PATCH] add promotion_methods --- src/flag_gems/utils/pointwise_dynamic.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index e6493943..2ef9885c 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -317,7 +317,6 @@ def generate_functional_pointwise_wrapper( num_output_tensor_index = 0 for i in range(op_desc.num_outputs()): type_promotion_args = ith_parameter_for_type_promotion(op_desc, i) - print(type_promotion_args) k_type_promotion = op_desc.ith_type_promotion_kind(i) code.writeline( ( @@ -711,7 +710,11 @@ def decorator(fn): if __name__ == "__main__": - @pointwise_dynamic(is_tensor=[True, False, True], dtypes=[None, float, None]) + @pointwise_dynamic( + is_tensor=[True, False, True], + dtypes=[None, float, None], + promotion_methods=[[0, 1, 2, "DEFAULT"]], + ) @triton.jit def saxpy(x, alpha, y): return x * alpha + y @@ -725,7 +728,9 @@ def saxpy(x, alpha, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(is_tensor=[True, False, True]) + @pointwise_dynamic( + is_tensor=[True, False, True], promotion_methods=[[0, 1, 2, "DEFAULT"]] + ) @triton.jit def saxpy(x, alpha, y): return x * alpha + y @@ -737,7 +742,7 @@ def saxpy(x, alpha, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(output_dtypes=[torch.bool]) + @pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def ge(x, y): return x > y @@ -749,7 +754,7 @@ def ge(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic() + @pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def ordinary(x, y): return tl.sin(x) + tl.cos(y) @@ -761,7 +766,7 @@ def ordinary(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic + @pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def ordinary2(x, y): return tl.sin(x) + tl.cos(y) @@ -773,7 +778,7 @@ def ordinary2(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic + @pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def ordinary2(x, y): return tl.sin(x) + tl.cos(y) @@ -787,7 +792,9 @@ def ordinary2(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) + @pointwise_dynamic( + is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]] + ) @triton.jit def eq(x, y): return x.to(tl.float32) == y.to(