Skip to content

Commit

Permalink
reformatted.
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Jul 2, 2024
1 parent 44c0597 commit 887bcd7
Showing 1 changed file with 5 additions and 18 deletions.
23 changes: 5 additions & 18 deletions src/flag_gems/ops/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,11 @@ def uint_to_uniform_float(x):

@triton.heuristics(
values={
# "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
"BLOCK": lambda args: 512 if args["N"] <= 512 else 1024,
"num_warps": lambda args: 4
if args["N"] <= 512
else 8
if args["N"] <= 1024
else 16,
"num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16, # fmt: skip
}
)
@triton.jit
@triton.jit(do_not_specialize=("philox_seed", "philox_offset"))
def dropout_forward_kernel(X, Y, N, p, philox_seed, philox_offset, BLOCK: tl.constexpr):
UNROLL = 4
philox_seed = philox_seed.to(tl.int64)
Expand Down Expand Up @@ -91,19 +86,11 @@ def dropout_forward_kernel(X, Y, N, p, philox_seed, philox_offset, BLOCK: tl.con

@triton.heuristics(
values={
"BLOCK": lambda args: 512
if args["N"] <= 512
else 1024
if args["N"] <= 1024
else 2048,
"num_warps": lambda args: 4
if args["N"] <= 512
else 8
if args["N"] <= 1024
else 16,
"BLOCK": lambda args: 512 if args["N"] <= 512 else 1024,
"num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16, # fmt: skip
}
)
@triton.jit
@triton.jit(do_not_specialize=("philox_seed", "philox_offset"))
def dropout_backward_kernel(
DY,
DX,
Expand Down

0 comments on commit 887bcd7

Please sign in to comment.