Improve dropout performance by unrolling 4xint random masks.
tongxin committed Jun 27, 2024
1 parent 9168f2d commit 963ca9e
Showing 1 changed file with 93 additions and 79 deletions.
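In short: the old kernel drew one random number per element with tl.rand, so the global Philox offset had to advance by N per launch. The new kernel calls tl.philox directly, which returns four 32-bit random words per counter, so each program instance handles UNROLL = 4 blocks of elements and the offset only advances by ceil(N / UNROLL). The autotuned N_BLOCK_SIZE configurations are also replaced by a heuristic that picks BLOCK and num_warps from N. In the diff below, the autotune / N_BLOCK_SIZE / tl.rand lines are the removed code and the heuristics / BLOCK / tl.philox lines are the additions.

The host-side arithmetic, as a minimal illustrative sketch (plain Python; the values of N, BLOCK, pid and k are made up for illustration, not taken from the commit):

import math

UNROLL = 4                                      # one Philox call yields four 32-bit random words
N = 10_000                                      # hypothetical element count
BLOCK = 1024                                    # hypothetical block size chosen by the heuristic

num_programs = math.ceil(N / (BLOCK * UNROLL))  # launch grid: one program per BLOCK * UNROLL elements
offset_increment = math.ceil(N / UNROLL)        # how far the global Philox offset advances per launch

pid, k = 2, 3                                   # program id and unrolled step, for illustration
off_k_start = pid * BLOCK * UNROLL + k * BLOCK  # first element of the k-th slice, as in off_0..off_3
print(num_programs, offset_increment, off_k_start)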
172 changes: 93 additions & 79 deletions src/flag_gems/ops/dropout.py
@@ -1,3 +1,4 @@
import math
import logging

import torch
@@ -7,39 +8,13 @@
from ..utils import libentry
from ..utils.random_utils import philox_cuda_seed_offset

try:
tl_rand_dtype = tl.int64

@triton.jit
def _rand(seed, offset):
offset = offset.to(tl_rand_dtype)

_grid = (1,)
_seed, _offset = philox_cuda_seed_offset(0)
_rand[_grid](_seed, _offset)
except Exception:
tl_rand_dtype = tl.int32

del _grid
del _seed
del _offset


@libentry()
@triton.autotune(
configs=[
triton.Config({"N_BLOCK_SIZE": 256}, num_warps=2, num_stages=4),
triton.Config({"N_BLOCK_SIZE": 256}, num_warps=2, num_stages=5),
triton.Config({"N_BLOCK_SIZE": 512}, num_warps=2, num_stages=4),
triton.Config({"N_BLOCK_SIZE": 512}, num_warps=2, num_stages=5),
triton.Config({"N_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
triton.Config({"N_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5),
triton.Config({"N_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4),
triton.Config({"N_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5),
],
key=[
"N",
],
UNROLL = 4

@triton.heuristics(
values={
"BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
"num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,
}
)
@triton.jit
def dropout_forward_kernel(
@@ -49,38 +24,54 @@ def dropout_forward_kernel(
p,
philox_seed,
philox_offset,
N_BLOCK_SIZE: tl.constexpr,
BLOCK: tl.constexpr
):
UNROLL = 4
philox_seed = philox_seed.to(tl.int64)
philox_offset = philox_offset.to(tl_rand_dtype)
pid = tl.program_id(0) * N_BLOCK_SIZE
offset = pid + tl.arange(0, N_BLOCK_SIZE)
mask = offset < N
X_ptr = X + offset
Y_ptr = Y + offset
inp = tl.load(X_ptr, mask=mask, other=0.0)
philox_offset = philox_offset + offset
pmask = tl.rand(philox_seed, philox_offset, n_rounds=6) > p
philox_offset = philox_offset.to(tl.int64)
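# split the 64-bit philox offset into two 32-bit counter words (c0 = low bits, c1 = high bits)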
c0 = (philox_offset & 0xffffffff).to(tl.uint32)
c1 = ((philox_offset >> 32) & 0xffffffff).to(tl.uint32)
i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
c0 += i4
_O = c0 * 0
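# one philox call produces four independent 32-bit random words; each is converted to a uniform float below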
r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
r0 = tl.uint_to_uniform_float(r0)
r1 = tl.uint_to_uniform_float(r1)
r2 = tl.uint_to_uniform_float(r2)
r3 = tl.uint_to_uniform_float(r3)

mask0 = r0 > p
mask1 = r1 > p
mask2 = r2 > p
mask3 = r3 > p
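# keep an element when its draw exceeds p; survivors are scaled by 1/(1 - p) to preserve the expectation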
p = 1.0 / (1.0 - p)
out = tl.where(pmask, inp * p, 0.0)
tl.store(Y_ptr, out.to(inp.dtype), mask=mask)


@libentry()
@triton.autotune(
[
triton.Config({"N_BLOCK_SIZE": 256}, num_warps=2, num_stages=4),
triton.Config({"N_BLOCK_SIZE": 256}, num_warps=2, num_stages=5),
triton.Config({"N_BLOCK_SIZE": 512}, num_warps=2, num_stages=4),
triton.Config({"N_BLOCK_SIZE": 512}, num_warps=2, num_stages=5),
triton.Config({"N_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
triton.Config({"N_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5),
triton.Config({"N_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4),
triton.Config({"N_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5),
],
key=[
"N",
],

off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
off_1 = off_0 + BLOCK
off_2 = off_1 + BLOCK
off_3 = off_2 + BLOCK
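# each program covers UNROLL * BLOCK contiguous elements; off_0..off_3 index the four BLOCK-wide slices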

x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy='evict_first')
x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy='evict_first')
x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy='evict_first')
x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy='evict_first')
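# evict_first hints that this streaming data will not be reused, letting the cache evict it first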

y0 = x0 * p * mask0 # tl.where(mask0, x0 * p, 0.0)
y1 = x1 * p * mask1 # tl.where(mask1, x1 * p, 0.0)
y2 = x2 * p * mask2 # tl.where(mask2, x2 * p, 0.0)
y3 = x3 * p * mask3 # tl.where(mask3, x3 * p, 0.0)

tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy='evict_first')
tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy='evict_first')
tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy='evict_first')
tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy='evict_first')


@triton.heuristics(
values={
"BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
"num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,
}
)
@triton.jit
def dropout_backward_kernel(
@@ -90,23 +81,46 @@ def dropout_backward_kernel(
p,
philox_seed,
philox_offset,
N_BLOCK_SIZE: tl.constexpr,
BLOCK: tl.constexpr,
):
UNROLL = 4
philox_seed = philox_seed.to(tl.int64)
philox_offset = philox_offset.to(tl_rand_dtype)
pid = tl.program_id(0) * N_BLOCK_SIZE
offset = pid + tl.arange(0, N_BLOCK_SIZE)
mask = offset < N
DY_ptr = DY + offset
DX_ptr = DX + offset
philox_offset = philox_offset + offset
pmask = tl.rand(philox_seed, philox_offset, n_rounds=6) > p
dy = tl.load(DY_ptr, mask=mask, other=0.0)

output = dy * pmask
philox_offset = philox_offset.to(tl.int64)
c0 = (philox_offset & 0xffffffff).to(tl.uint32)
c1 = ((philox_offset >> 32) & 0xffffffff).to(tl.uint32)
i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
c0 += i4
_O = c0 * 0
r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
r0 = tl.uint_to_uniform_float(r0)
r1 = tl.uint_to_uniform_float(r1)
r2 = tl.uint_to_uniform_float(r2)
r3 = tl.uint_to_uniform_float(r3)

mask0 = r0 > p
mask1 = r1 > p
mask2 = r2 > p
mask3 = r3 > p
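# same seed, offset and counter layout as the forward kernel, so the identical dropout mask is regenerated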
off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
off_1 = off_0 + BLOCK
off_2 = off_1 + BLOCK
off_3 = off_2 + BLOCK

dy_0 = tl.load(DY + off_0, mask=off_0 < N, other=0.0, eviction_policy='evict_first')
dy_1 = tl.load(DY + off_1, mask=off_1 < N, other=0.0, eviction_policy='evict_first')
dy_2 = tl.load(DY + off_2, mask=off_2 < N, other=0.0, eviction_policy='evict_first')
dy_3 = tl.load(DY + off_3, mask=off_3 < N, other=0.0, eviction_policy='evict_first')

p = 1.0 / (1.0 - p)
output *= p
tl.store(DX_ptr, output.to(dy.dtype), mask=mask)
dx_0 = p * dy_0 * mask0
dx_1 = p * dy_1 * mask1
dx_2 = p * dy_2 * mask2
dx_3 = p * dy_3 * mask3
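# surviving gradients are scaled by 1/(1 - p), matching the forward scaling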

tl.store(DX + off_0, dx_0, mask=off_0 < N, eviction_policy='evict_first')
tl.store(DX + off_1, dx_1, mask=off_1 < N, eviction_policy='evict_first')
tl.store(DX + off_2, dx_2, mask=off_2 < N, eviction_policy='evict_first')
tl.store(DX + off_3, dx_3, mask=off_3 < N, eviction_policy='evict_first')


class NativeDropout(torch.autograd.Function):
@@ -117,10 +131,10 @@ def forward(ctx, x, p, train):
x = x.contiguous()
out = torch.empty_like(x)
N = x.numel()
grid_fn = lambda meta: (triton.cdiv(N, meta["N_BLOCK_SIZE"]),)
grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
# (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
# hence we cannot obtain the per thread offset as in Pytorch.
increment = N
increment = triton.cdiv(N, UNROLL)
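# each philox offset now yields UNROLL random values, so the global offset advances by ceil(N / UNROLL)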
philox_seed, philox_offset = philox_cuda_seed_offset(increment)
dropout_forward_kernel[grid_fn](x, out, N, p, philox_seed, philox_offset)
ctx.p = p
@@ -134,7 +148,7 @@ def backward(ctx, grad_outputs, kwargs):
grad_outputs = grad_outputs.contiguous()
grad_inputs = torch.empty_like(grad_outputs)
N = grad_outputs.numel()
grid_fn = lambda meta: (triton.cdiv(N, meta["N_BLOCK_SIZE"]),)
grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
dropout_backward_kernel[grid_fn](
grad_outputs, grad_inputs, N, ctx.p, ctx.philox_seed, ctx.philox_offset
)