Skip to content

Commit

Permalink
add promotion_methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Bowen12992 committed Jun 27, 2024
1 parent 6c80334 commit 204e844
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/flag_gems/utils/pointwise_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,6 @@ def generate_functional_pointwise_wrapper(
num_output_tensor_index = 0
for i in range(op_desc.num_outputs()):
type_promotion_args = ith_parameter_for_type_promotion(op_desc, i)
print(type_promotion_args)
k_type_promotion = op_desc.ith_type_promotion_kind(i)
code.writeline(
(
Expand Down Expand Up @@ -711,7 +710,11 @@ def decorator(fn):

if __name__ == "__main__":

@pointwise_dynamic(is_tensor=[True, False, True], dtypes=[None, float, None])
@pointwise_dynamic(
is_tensor=[True, False, True],
dtypes=[None, float, None],
promotion_methods=[[0, 1, 2, "DEFAULT"]],
)
@triton.jit
def saxpy(x, alpha, y):
return x * alpha + y
Expand All @@ -725,7 +728,9 @@ def saxpy(x, alpha, y):
torch.testing.assert_close(out1, out2)
print()

@pointwise_dynamic(is_tensor=[True, False, True])
@pointwise_dynamic(
is_tensor=[True, False, True], promotion_methods=[[0, 1, 2, "DEFAULT"]]
)
@triton.jit
def saxpy(x, alpha, y):
return x * alpha + y
Expand All @@ -737,7 +742,7 @@ def saxpy(x, alpha, y):
torch.testing.assert_close(out1, out2)
print()

@pointwise_dynamic(output_dtypes=[torch.bool])
@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]])
@triton.jit
def ge(x, y):
return x > y
Expand All @@ -749,7 +754,7 @@ def ge(x, y):
torch.testing.assert_close(out1, out2)
print()

@pointwise_dynamic()
@pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]])
@triton.jit
def ordinary(x, y):
return tl.sin(x) + tl.cos(y)
Expand All @@ -761,7 +766,7 @@ def ordinary(x, y):
torch.testing.assert_close(out1, out2)
print()

@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]])
@triton.jit
def ordinary2(x, y):
return tl.sin(x) + tl.cos(y)
Expand All @@ -773,7 +778,7 @@ def ordinary2(x, y):
torch.testing.assert_close(out1, out2)
print()

@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]])
@triton.jit
def ordinary2(x, y):
return tl.sin(x) + tl.cos(y)
Expand All @@ -787,7 +792,9 @@ def ordinary2(x, y):
torch.testing.assert_close(out1, out2)
print()

@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool])
@pointwise_dynamic(
is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]
)
@triton.jit
def eq(x, y):
return x.to(tl.float32) == y.to(
Expand Down

0 comments on commit 204e844

Please sign in to comment.