Commit 44c0597

Reformatted.

tongxin committed Jul 1, 2024
1 parent 0722a99 commit 44c0597
Showing 1 changed file with 46 additions and 42 deletions.

src/flag_gems/ops/dropout.py (88 changes: 46 additions & 42 deletions)
@@ -1,14 +1,11 @@
import math
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry
from ..utils.random_utils import philox_cuda_seed_offset


try:
    uint_to_uniform_float = tl.uint_to_uniform_float
except AttributeError:
@@ -26,37 +23,36 @@ def uint_to_uniform_float(x):
            x = x.to(tl.int32, bitcast=True)
            scale = 4.6566127342e-10
        else:
            tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64))
            tl.static_assert(
                tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)
            )
            x = x.to(tl.int64, bitcast=True)
            scale = 1.0842020432385337e-19
        x = tl.where(x < 0, -x - 1, x)
        return x * scale
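A plain-Python mirror of the 32-bit path of this helper may make the bit trick easier to follow; this is an illustrative sketch, not part of the diff:

    def uint32_to_uniform_float(x: int) -> float:
        # Reinterpret the uint32 as int32 (the bitcast in the kernel).
        if x >= 0x80000000:
            x -= 0x100000000
        # Fold negative values onto [0, 2**31 - 1], as tl.where does above.
        if x < 0:
            x = -x - 1
        # Scale by roughly 2**-31 so the result lands in [0, 1).
        return x * 4.6566127342e-10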


UNROLL=4
UNROLL = 4


@triton.heuristics(
    values={
        # "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
        "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024,
        "num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,
        "num_warps": lambda args: 4
        if args["N"] <= 512
        else 8
        if args["N"] <= 1024
        else 16,
    }
)
@triton.jit
def dropout_forward_kernel(
    X,
    Y,
    N,
    p,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr
):
def dropout_forward_kernel(X, Y, N, p, philox_seed, philox_offset, BLOCK: tl.constexpr):
    UNROLL = 4
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xffffffff).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xffffffff).to(tl.uint32)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
@@ -77,26 +73,34 @@ def dropout_forward_kernel(
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy='evict_first')
    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy='evict_first')
    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy='evict_first')
    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy='evict_first')
    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy="evict_first")
    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy="evict_first")
    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy="evict_first")
    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy="evict_first")

    y0 = x0 * p * mask0 # tl.where(mask0, x0 * p, 0.0)
    y1 = x1 * p * mask1 # tl.where(mask1, x1 * p, 0.0)
    y2 = x2 * p * mask2 # tl.where(mask2, x2 * p, 0.0)
    y3 = x3 * p * mask3 # tl.where(mask3, x3 * p, 0.0)
    y0 = x0 * p * mask0  # tl.where(mask0, x0 * p, 0.0)
    y1 = x1 * p * mask1  # tl.where(mask1, x1 * p, 0.0)
    y2 = x2 * p * mask2  # tl.where(mask2, x2 * p, 0.0)
    y3 = x3 * p * mask3  # tl.where(mask3, x3 * p, 0.0)

    tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy='evict_first')
    tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy='evict_first')
    tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy='evict_first')
    tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy='evict_first')
    tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
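Both kernels derive the dropout mask from the Philox counter state (philox_seed, philox_offset) instead of reading a stored mask tensor: each program regenerates the same random bits from its element indices, which is why the seed and offset appear in both kernel signatures. Kept elements are rescaled so the output's expectation matches the input (inverted dropout); in the visible lines the forward kernel multiplies by p, which presumably already carries the 1/(1 - p) scale or is inverted in the collapsed lines above. A quick PyTorch sanity check of that convention (illustrative, not part of the diff):

    import torch

    x = torch.ones(1_000_000)
    p = 0.2  # drop probability
    mask = (torch.rand_like(x) > p).float()  # 1 = keep, 0 = drop
    y = x * mask / (1.0 - p)  # inverted dropout
    print(y.mean())  # ~1.0: the expectation of the input is preserved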


@triton.heuristics(
    values={
        "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
        "num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,
        "BLOCK": lambda args: 512
        if args["N"] <= 512
        else 1024
        if args["N"] <= 1024
        else 2048,
        "num_warps": lambda args: 4
        if args["N"] <= 512
        else 8
        if args["N"] <= 1024
        else 16,
    }
)
@triton.jit
@@ -112,8 +116,8 @@ def dropout_backward_kernel(
    UNROLL = 4
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xffffffff).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xffffffff).to(tl.uint32)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
@@ -132,21 +136,21 @@
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    dy_0 = tl.load(DY + off_0, mask=off_0 < N, other=0.0, eviction_policy='evict_first')
    dy_1 = tl.load(DY + off_1, mask=off_1 < N, other=0.0, eviction_policy='evict_first')
    dy_2 = tl.load(DY + off_2, mask=off_2 < N, other=0.0, eviction_policy='evict_first')
    dy_3 = tl.load(DY + off_3, mask=off_3 < N, other=0.0, eviction_policy='evict_first')
    dy_0 = tl.load(DY + off_0, mask=off_0 < N, other=0.0, eviction_policy="evict_first")
    dy_1 = tl.load(DY + off_1, mask=off_1 < N, other=0.0, eviction_policy="evict_first")
    dy_2 = tl.load(DY + off_2, mask=off_2 < N, other=0.0, eviction_policy="evict_first")
    dy_3 = tl.load(DY + off_3, mask=off_3 < N, other=0.0, eviction_policy="evict_first")

    p = 1.0 / (1.0 - p)
    dx_0 = p * dy_0 * mask0
    dx_1 = p * dy_1 * mask1
    dx_2 = p * dy_2 * mask2
    dx_3 = p * dy_3 * mask3

    tl.store(DX + off_0, dx_0, mask=off_0 < N, eviction_policy='evict_first')
    tl.store(DX + off_1, dx_1, mask=off_1 < N, eviction_policy='evict_first')
    tl.store(DX + off_2, dx_2, mask=off_2 < N, eviction_policy='evict_first')
    tl.store(DX + off_3, dx_3, mask=off_3 < N, eviction_policy='evict_first')
    tl.store(DX + off_0, dx_0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(DX + off_1, dx_1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(DX + off_2, dx_2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(DX + off_3, dx_3, mask=off_3 < N, eviction_policy="evict_first")
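The backward kernel regenerates exactly the mask used in the forward pass from the same (philox_seed, philox_offset) pair, then applies the matching 1/(1 - p) scale to the incoming gradient. The counter setup shared by both kernels splits the 64-bit offset into two 32-bit words; a plain-Python mirror (names are illustrative):

    def split_philox_offset(philox_offset: int) -> tuple[int, int]:
        c0 = philox_offset & 0xFFFFFFFF          # low 32 bits
        c1 = (philox_offset >> 32) & 0xFFFFFFFF  # high 32 bits
        # In the kernels, the element index is then added to c0 so that
        # every element gets a distinct Philox counter.
        return c0, c1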


class NativeDropout(torch.autograd.Function):
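The class body is collapsed in this view. As a readability aid, a plain-PyTorch autograd Function with the same forward/backward semantics might look like the sketch below; unlike the Triton version, which saves only the Philox state and regenerates the mask in backward, this reference saves the mask tensor itself (all names here are illustrative, not the FlagGems API):

    import torch

    class DropoutReference(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, p):
            # 1 = keep, 0 = drop; kept values are scaled by 1/(1 - p).
            mask = (torch.rand_like(x) > p).to(x.dtype)
            ctx.save_for_backward(mask)
            ctx.scale = 1.0 / (1.0 - p)
            return x * mask * ctx.scale

        @staticmethod
        def backward(ctx, grad_out):
            (mask,) = ctx.saved_tensors
            # Reuse the saved mask and scale on the gradient, mirroring
            # dropout_backward_kernel above.
            return grad_out * mask * ctx.scale, None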