From 887bcd715c6959fe27796d97816149f8e7b14187 Mon Sep 17 00:00:00 2001
From: Tongxin Bai
Date: Tue, 2 Jul 2024 11:47:09 +0800
Subject: [PATCH] reformatted.

---
 src/flag_gems/ops/dropout.py | 23 +++++------------------
 1 file changed, 5 insertions(+), 18 deletions(-)

diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py
index 7d9cd0a6..3aba15c5 100644
--- a/src/flag_gems/ops/dropout.py
+++ b/src/flag_gems/ops/dropout.py
@@ -37,16 +37,11 @@ def uint_to_uniform_float(x):
 
 @triton.heuristics(
     values={
-        # "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
         "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024,
-        "num_warps": lambda args: 4
-        if args["N"] <= 512
-        else 8
-        if args["N"] <= 1024
-        else 16,
+        "num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,  # fmt: skip
     }
 )
-@triton.jit
+@triton.jit(do_not_specialize=("philox_seed", "philox_offset"))
 def dropout_forward_kernel(X, Y, N, p, philox_seed, philox_offset, BLOCK: tl.constexpr):
     UNROLL = 4
     philox_seed = philox_seed.to(tl.int64)
@@ -91,19 +86,11 @@ def dropout_forward_kernel(X, Y, N, p, philox_seed, philox_offset, BLOCK: tl.con
 
 @triton.heuristics(
     values={
-        "BLOCK": lambda args: 512
-        if args["N"] <= 512
-        else 1024
-        if args["N"] <= 1024
-        else 2048,
-        "num_warps": lambda args: 4
-        if args["N"] <= 512
-        else 8
-        if args["N"] <= 1024
-        else 16,
+        "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024,
+        "num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,  # fmt: skip
     }
 )
-@triton.jit
+@triton.jit(do_not_specialize=("philox_seed", "philox_offset"))
 def dropout_backward_kernel(
     DY,
     DX,