Skip to content

Commit

Permalink
Definesuint_to_uniform_float in dropout local file if import fails.
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Jul 1, 2024
1 parent 963ca9e commit 0722a99
Showing 1 changed file with 35 additions and 9 deletions.
44 changes: 35 additions & 9 deletions src/flag_gems/ops/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,37 @@
from ..utils import libentry
from ..utils.random_utils import philox_cuda_seed_offset


try:
uint_to_uniform_float = tl.uint_to_uniform_float
except AttributeError:
# Copied from triton.language package for compatibility
@triton.jit
def uint_to_uniform_float(x):
"""
Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).
"""
# TODO: fix frontend issues and cleanup
# conditions can be simplified
# scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
# maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
x = x.to(tl.int32, bitcast=True)
scale = 4.6566127342e-10
else:
tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64))
x = x.to(tl.int64, bitcast=True)
scale = 1.0842020432385337e-19
x = tl.where(x < 0, -x - 1, x)
return x * scale


UNROLL=4

@triton.heuristics(
values={
"BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
# "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
"BLOCK": lambda args: 512 if args["N"] <= 512 else 1024,
"num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,
}
)
Expand All @@ -35,10 +61,10 @@ def dropout_forward_kernel(
c0 += i4
_O = c0 * 0
r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
r0 = tl.uint_to_uniform_float(r0)
r1 = tl.uint_to_uniform_float(r1)
r2 = tl.uint_to_uniform_float(r2)
r3 = tl.uint_to_uniform_float(r3)
r0 = uint_to_uniform_float(r0)
r1 = uint_to_uniform_float(r1)
r2 = uint_to_uniform_float(r2)
r3 = uint_to_uniform_float(r3)

mask0 = r0 > p
mask1 = r1 > p
Expand Down Expand Up @@ -92,10 +118,10 @@ def dropout_backward_kernel(
c0 += i4
_O = c0 * 0
r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
r0 = tl.uint_to_uniform_float(r0)
r1 = tl.uint_to_uniform_float(r1)
r2 = tl.uint_to_uniform_float(r2)
r3 = tl.uint_to_uniform_float(r3)
r0 = uint_to_uniform_float(r0)
r1 = uint_to_uniform_float(r1)
r2 = uint_to_uniform_float(r2)
r3 = uint_to_uniform_float(r3)

mask0 = r0 > p
mask1 = r1 > p
Expand Down

0 comments on commit 0722a99

Please sign in to comment.