From 44c05978bd5e40d4d98cc8b36ea2cf237c5b9d7f Mon Sep 17 00:00:00 2001
From: Tongxin Bai
Date: Mon, 1 Jul 2024 13:43:56 +0800
Subject: [PATCH] Reformatted.

---
 src/flag_gems/ops/dropout.py | 88 +++++++++++++++++++-----------------
 1 file changed, 46 insertions(+), 42 deletions(-)

diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py
index 6f1dfab2..7d9cd0a6 100644
--- a/src/flag_gems/ops/dropout.py
+++ b/src/flag_gems/ops/dropout.py
@@ -1,14 +1,11 @@
-import math
 import logging
 
 import torch
 import triton
 import triton.language as tl
 
-from ..utils import libentry
 from ..utils.random_utils import philox_cuda_seed_offset
 
-
 try:
     uint_to_uniform_float = tl.uint_to_uniform_float
 except AttributeError:
@@ -26,37 +23,36 @@ def uint_to_uniform_float(x):
             x = x.to(tl.int32, bitcast=True)
             scale = 4.6566127342e-10
         else:
-            tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64))
+            tl.static_assert(
+                tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)
+            )
             x = x.to(tl.int64, bitcast=True)
             scale = 1.0842020432385337e-19
         x = tl.where(x < 0, -x - 1, x)
         return x * scale
 
 
-UNROLL=4
+UNROLL = 4
+
 
 @triton.heuristics(
     values={
         # "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
         "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024,
-        "num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,
+        "num_warps": lambda args: 4
+        if args["N"] <= 512
+        else 8
+        if args["N"] <= 1024
+        else 16,
     }
 )
 @triton.jit
-def dropout_forward_kernel(
-    X,
-    Y,
-    N,
-    p,
-    philox_seed,
-    philox_offset,
-    BLOCK: tl.constexpr
-):
+def dropout_forward_kernel(X, Y, N, p, philox_seed, philox_offset, BLOCK: tl.constexpr):
     UNROLL = 4
     philox_seed = philox_seed.to(tl.int64)
     philox_offset = philox_offset.to(tl.int64)
-    c0 = (philox_offset & 0xffffffff).to(tl.uint32)
-    c1 = ((philox_offset >> 32) & 0xffffffff).to(tl.uint32)
+    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
+    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
     i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
     c0 += i4
     _O = c0 * 0
@@ -77,26 +73,34 @@ def dropout_forward_kernel(
     off_2 = off_1 + BLOCK
     off_3 = off_2 + BLOCK
 
-    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy='evict_first')
-    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy='evict_first')
-    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy='evict_first')
-    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy='evict_first')
+    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy="evict_first")
+    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy="evict_first")
+    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy="evict_first")
+    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy="evict_first")
 
-    y0 = x0 * p * mask0 # tl.where(mask0, x0 * p, 0.0)
-    y1 = x1 * p * mask1 # tl.where(mask1, x1 * p, 0.0)
-    y2 = x2 * p * mask2 # tl.where(mask2, x2 * p, 0.0)
-    y3 = x3 * p * mask3 # tl.where(mask3, x3 * p, 0.0)
+    y0 = x0 * p * mask0  # tl.where(mask0, x0 * p, 0.0)
+    y1 = x1 * p * mask1  # tl.where(mask1, x1 * p, 0.0)
+    y2 = x2 * p * mask2  # tl.where(mask2, x2 * p, 0.0)
+    y3 = x3 * p * mask3  # tl.where(mask3, x3 * p, 0.0)
 
-    tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy='evict_first')
-    tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy='evict_first')
-    tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy='evict_first')
-    tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy='evict_first')
+    tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
+    tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
+    tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
+    tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
 
 
 @triton.heuristics(
     values={
-        "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024 if args["N"] <= 1024 else 2048,
-        "num_warps": lambda args: 4 if args["N"] <= 512 else 8 if args["N"] <= 1024 else 16,
+        "BLOCK": lambda args: 512
+        if args["N"] <= 512
+        else 1024
+        if args["N"] <= 1024
+        else 2048,
+        "num_warps": lambda args: 4
+        if args["N"] <= 512
+        else 8
+        if args["N"] <= 1024
+        else 16,
     }
 )
 @triton.jit
@@ -112,8 +116,8 @@ def dropout_backward_kernel(
     UNROLL = 4
     philox_seed = philox_seed.to(tl.int64)
     philox_offset = philox_offset.to(tl.int64)
-    c0 = (philox_offset & 0xffffffff).to(tl.uint32)
-    c1 = ((philox_offset >> 32) & 0xffffffff).to(tl.uint32)
+    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
+    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
     i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
     c0 += i4
     _O = c0 * 0
@@ -132,21 +136,21 @@ def dropout_backward_kernel(
     off_2 = off_1 + BLOCK
     off_3 = off_2 + BLOCK
 
-    dy_0 = tl.load(DY + off_0, mask=off_0 < N, other=0.0, eviction_policy='evict_first')
-    dy_1 = tl.load(DY + off_1, mask=off_1 < N, other=0.0, eviction_policy='evict_first')
-    dy_2 = tl.load(DY + off_2, mask=off_2 < N, other=0.0, eviction_policy='evict_first')
-    dy_3 = tl.load(DY + off_3, mask=off_3 < N, other=0.0, eviction_policy='evict_first')
+    dy_0 = tl.load(DY + off_0, mask=off_0 < N, other=0.0, eviction_policy="evict_first")
+    dy_1 = tl.load(DY + off_1, mask=off_1 < N, other=0.0, eviction_policy="evict_first")
+    dy_2 = tl.load(DY + off_2, mask=off_2 < N, other=0.0, eviction_policy="evict_first")
+    dy_3 = tl.load(DY + off_3, mask=off_3 < N, other=0.0, eviction_policy="evict_first")
 
     p = 1.0 / (1.0 - p)
     dx_0 = p * dy_0 * mask0
     dx_1 = p * dy_1 * mask1
     dx_2 = p * dy_2 * mask2
-    dx_3 = p * dy_3 * mask3 
+    dx_3 = p * dy_3 * mask3
 
-    tl.store(DX + off_0, dx_0, mask=off_0 < N, eviction_policy='evict_first')
-    tl.store(DX + off_1, dx_1, mask=off_1 < N, eviction_policy='evict_first')
-    tl.store(DX + off_2, dx_2, mask=off_2 < N, eviction_policy='evict_first')
-    tl.store(DX + off_3, dx_3, mask=off_3 < N, eviction_policy='evict_first')
+    tl.store(DX + off_0, dx_0, mask=off_0 < N, eviction_policy="evict_first")
+    tl.store(DX + off_1, dx_1, mask=off_1 < N, eviction_policy="evict_first")
+    tl.store(DX + off_2, dx_2, mask=off_2 < N, eviction_policy="evict_first")
+    tl.store(DX + off_3, dx_3, mask=off_3 < N, eviction_policy="evict_first")
 
 
 class NativeDropout(torch.autograd.Function):
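
Note: the patch above is formatting-only, so kernel behaviour is unchanged. For orientation, a minimal host-side launch sketch for dropout_forward_kernel follows. It is illustrative, not part of this patch: the tensor names and the grid lambda are made up here, the grid sizing mirrors the offset arithmetic in the kernel (each program instance covers BLOCK * UNROLL elements via off_0..off_3), and it assumes philox_cuda_seed_offset(increment) returns a (seed, offset) pair; the real entry point is the NativeDropout autograd function whose class line appears as trailing context above.

# Illustrative sketch only; names other than the kernel, UNROLL, and
# philox_cuda_seed_offset are assumptions, not taken from this patch.
import torch
import triton

from flag_gems.ops.dropout import UNROLL, dropout_forward_kernel
from flag_gems.utils.random_utils import philox_cuda_seed_offset

x = torch.randn(8192, device="cuda")
out = torch.empty_like(x)
N = x.numel()
p = 0.5  # dropout probability, interpreted as the kernel expects

# Each program instance handles BLOCK * UNROLL contiguous elements, so the
# 1D grid is sized accordingly. BLOCK and num_warps are supplied by the
# @triton.heuristics decorator and are visible to the grid callable via meta.
grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

# Assumption: philox_cuda_seed_offset advances the global philox offset by
# `increment` and returns the (seed, offset) pair used for this launch.
increment = triton.cdiv(N, UNROLL)
philox_seed, philox_offset = philox_cuda_seed_offset(increment)

dropout_forward_kernel[grid](x, out, N, p, philox_seed, philox_offset)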