optimize layernorm forward (#218)
iclementine authored Sep 20, 2024
1 parent 204f3d4 commit da86496
Showing 2 changed files with 213 additions and 64 deletions.
275 changes: 212 additions & 63 deletions src/flag_gems/ops/layernorm.py
@@ -8,85 +8,190 @@
from ..utils import libentry


def cfggen():
    block_m = [1, 2, 4]
    block_n = [1024, 2048, 4096]
    warps = [4, 8, 16]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w)
        for m in block_m
        for n in block_n
        for w in warps
    ]
    return configs


@triton.jit
def prev_multiple_of(a, b):
    # the largest multiple of b that is strictly less than a
    return tl.cdiv(a, b) * b - b
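
For intuition, here is a plain-Python sketch of the same computation on host-side integers; the helper name and the sample values are illustrative only and not part of the commit:

def prev_multiple_of_host(a: int, b: int) -> int:
    # ceil-divide, scale back up, then step down one multiple of b
    return -(-a // b) * b - b


assert prev_multiple_of_host(100, 7) == 98
assert prev_multiple_of_host(40999, 2048) == 40960
assert prev_multiple_of_host(4096, 1024) == 3072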


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit(do_not_specialize=["eps"])
def layer_norm_kernel(
    X,
    Y,
    W,
    B,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    M,
    N,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    row = pid * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    row_mask = row < M
    X += row * N
    Y += row * N

    # Compute mean
    _mean = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=1) / N
    mean = mean[:, None]

    # Compute variance
    _var = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        x = tl.where(col_mask, x - mean, 0.0)
        _var += x * x
    var = tl.sum(_var, axis=1) / N
    var = var[:, None]
    rstd = 1 / tl.sqrt(var + eps)

    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)

    # Normalize and apply linear transformation
    for off in range(0, N, BLOCK_COL_SIZE):
        cols = off + tl.arange(0, BLOCK_COL_SIZE)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        w = tl.load(W + cols, col_mask)
        b = tl.load(B + cols, col_mask)
        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
        x = tl.where(col_mask, x - mean, 0.0)
        x_hat = x * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)


@libentry()
@triton.autotune(
    configs=[triton.Config({}, num_warps=w) for w in [4, 8, 16]],
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # using 1d tile makes code clean
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)

    n_offsets = tl.arange(0, TILE_N)
    mask = n_offsets < N

    x = tl.load(in_ptr + pid * N + n_offsets, mask, other=0.0).to(tl.float32)
    m = tl.sum(x) / N
    d = x - m  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s)  # sum of square of deviation
    var = sum_square / N
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    w = tl.load(weight_ptr + n_offsets, mask=mask)
    b = tl.load(bias_ptr + n_offsets, mask=mask)
    out = (x - m) * rstd * w + b

    tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)
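
As a sanity check on the single-tile math above, the per-row mean, rstd, and normalized output should agree with a direct PyTorch computation. The snippet below is a sketch for checking, not part of the commit; tensor names, sizes, and tolerances are assumed:

import torch

N = 4096
x_row = torch.randn(N)
weight = torch.randn(N)
bias = torch.randn(N)

m = x_row.mean()
var = ((x_row - m) ** 2).mean()      # population variance, i.e. sum_square / N
rstd = torch.rsqrt(var + 1e-5)
out_row = (x_row - m) * rstd * weight + bias

ref = torch.nn.functional.layer_norm(x_row, (N,), weight, bias, eps=1e-5)
torch.testing.assert_close(out_row, ref, atol=1e-5, rtol=1e-5)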


@libentry()
@triton.autotune(
    configs=[triton.Config({}, num_warps=w) for w in [4, 8, 16]],
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_persistent_kernel_multiline(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)
    m_offsets = pid * TILE_M + tl.arange(0, TILE_M)
    m_mask = m_offsets < M

    n_offsets = tl.arange(0, TILE_N)[None, :]
    n_mask = n_offsets < N
    mask = m_mask[:, None] & n_mask

    x = tl.load(in_ptr + m_offsets[:, None] * N + n_offsets, mask, other=0.0).to(
        tl.float32
    )
    m = tl.sum(x, axis=1) / N
    d = x - m[:, None]  # deviation
    s = tl.where(mask, d * d, 0)
    sum_square = tl.sum(s, axis=1)  # sum of square of deviation
    var = sum_square / N
    rstd = tl.math.rsqrt(var + eps)

    tl.store(out_mean_ptr + m_offsets, m, mask=m_mask)
    tl.store(out_rstd_ptr + m_offsets, rstd, mask=m_mask)

    w = tl.load(weight_ptr + n_offsets, mask=n_mask)
    b = tl.load(bias_ptr + n_offsets, mask=n_mask)
    out = (x - m[:, None]) * rstd[:, None] * w + b

    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets, out, mask=mask)
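
This multiline variant packs several short rows into one program. A rough sketch of how its tile sizes end up being chosen (the numbers mirror the wrapper further down and the new test shape; treat them as illustrative):

import triton

# pack enough short rows into one program to reach roughly 1024 elements per tile
N = 100                                   # e.g. the (4096, 100) test shape
TILE_N = triton.next_power_of_2(N)        # -> 128
TILE_M = triton.cdiv(1024, TILE_N)        # -> 8 rows per program
num_programs = triton.cdiv(4096, TILE_M)  # -> 512 programs for M = 4096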


@libentry()
@triton.autotune(
    configs=[
        triton.Config({"TILE_N": tile_n}, num_warps=w)
        for tile_n in [1024, 2048, 4096, 8192]
        for w in [4, 8, 16]
    ],
    key=["M", "N"],
)
@triton.jit(do_not_specialize=["eps"])
def layer_norm_loop_kernel(
    in_ptr,
    out_ptr,
    weight_ptr,
    bias_ptr,
    out_mean_ptr,  # pointer to the mean
    out_rstd_ptr,  # pointer to the 1/std
    M,
    N,
    eps,
    TILE_N: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    pid = tl.program_id(0)

    # Compute mean
    m = tl.zeros((TILE_N,), dtype=tl.float32)  # mean
    s = tl.zeros((TILE_N,), dtype=tl.float32)  # sum((x - m)^2)
    cnt = tl.zeros((TILE_N,), dtype=tl.int32)
    num_steps = tl.cdiv(N, TILE_N)
    for step in range(0, num_steps - 1, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets).to(tl.float32)
        new_m = m + (x - m) / (step + 1)
        new_s = s + (x - new_m) * (x - m)
        cnt += 1
        m = new_m
        s = new_s

    # the last step
    for step in range(num_steps - 1, num_steps, 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(in_ptr + pid * N + n_offsets, mask=mask).to(tl.float32)
        new_m = tl.where(mask, m + (x - m) / (step + 1), m)
        new_s = tl.where(mask, s + (x - new_m) * (x - m), s)
        cnt += mask.to(tl.int32)
        m = new_m
        s = new_s

    final_m = tl.sum(m * cnt) / N
    var = tl.sum(s + cnt * (m - final_m) * (m - final_m)) / N
    rstd = tl.math.rsqrt(var + eps)
    m = final_m
    # Write mean / rstd
    tl.store(out_mean_ptr + pid, m)
    tl.store(out_rstd_ptr + pid, rstd)

    # reverse the order of the second sweep
    # Normalize and apply linear transformation
    prev_multiple = prev_multiple_of(N, TILE_N)
    # the first step, masking is needed
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x = tl.load(
            in_ptr + pid * N + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(tl.float32)
        w = tl.load(weight_ptr + n_offsets, mask=mask)
        b = tl.load(bias_ptr + n_offsets, mask=mask)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out, mask=mask)

    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x = tl.load(in_ptr + pid * N + n_offsets, eviction_policy="evict_first").to(
            tl.float32
        )
        w = tl.load(weight_ptr + n_offsets)
        b = tl.load(bias_ptr + n_offsets)
        out = w * (x - m) * rstd + b
        tl.store(out_ptr + pid * N + n_offsets, out)
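
The loop kernel keeps a Welford-style running mean and squared-deviation sum per lane of the tile and merges the lanes at the end. Below is a small host-side check of that merge identity; the sizes and equal per-lane counts are made up for illustration and are not part of the commit:

import torch

# var = sum_j(s_j + cnt_j * (m_j - final_m)^2) / N, combining per-lane statistics
x = torch.randn(4096)
lanes = x.view(-1, 8)                     # pretend TILE_N = 8 lanes with equal counts
m = lanes.mean(dim=0)                     # running means, one per lane
s = ((lanes - m) ** 2).sum(dim=0)         # per-lane sums of squared deviations
cnt = torch.full((8,), lanes.shape[0], dtype=torch.float32)

final_m = (m * cnt).sum() / x.numel()
var = (s + cnt * (m - final_m) ** 2).sum() / x.numel()
torch.testing.assert_close(var, x.var(unbiased=False), atol=1e-4, rtol=1e-4)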


@libentry()
@@ -211,16 +316,60 @@ def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True)
        # M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)
        M = x.numel() // N

        x = x.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        y = torch.empty_like(x)
        mean = torch.empty(M, dtype=x.dtype, device=x.device)
        rstd = torch.empty(M, dtype=x.dtype, device=x.device)
        grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)

        with torch.cuda.device(x.device):
            layer_norm_kernel[grid](x, y, weight, bias, mean, rstd, M, N, eps)

        with torch.cuda.device(x.device):
            if N <= 128:
                TILE_N = triton.next_power_of_2(N)
                TILE_M = triton.cdiv(1024, TILE_N)
                grid = (triton.cdiv(M, TILE_M), 1, 1)
                layer_norm_persistent_kernel_multiline[grid](
                    x,
                    y,
                    weight,
                    bias,
                    mean,
                    rstd,
                    M,
                    N,
                    eps,
                    TILE_M,
                    TILE_N,
                )
            elif N <= 4096:
                TILE_N = triton.next_power_of_2(N)
                grid = (M, 1, 1)
                layer_norm_persistent_kernel[grid](
                    x,
                    y,
                    weight,
                    bias,
                    mean,
                    rstd,
                    M,
                    N,
                    eps,
                    TILE_N,
                )
            else:
                grid = (M, 1, 1)
                layer_norm_loop_kernel[grid](
                    x,
                    y,
                    weight,
                    bias,
                    mean,
                    rstd,
                    M,
                    N,
                    eps,
                )
        ctx.save_for_backward(x, weight, mean, rstd)
        ctx.M = M
        ctx.N = N
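
The wrapper now selects a kernel by row length N. Here is a small stand-alone mirror of that dispatch; the helper below is illustrative only, the thresholds come from the code above, and the sample N values match the test shapes:

def pick_kernel(N: int) -> str:
    # Illustrative mirror of the dispatch above, not code from the commit.
    if N <= 128:
        return "layer_norm_persistent_kernel_multiline"  # several rows per program
    elif N <= 4096:
        return "layer_norm_persistent_kernel"            # one row, one tile
    else:
        return "layer_norm_loop_kernel"                  # one row, looped tiles


assert pick_kernel(100) == "layer_norm_persistent_kernel_multiline"
assert pick_kernel(256) == "layer_norm_persistent_kernel"
assert pick_kernel(40999) == "layer_norm_loop_kernel"
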
2 changes: 1 addition & 1 deletion tests/test_norm_ops.py
@@ -78,7 +78,7 @@ def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype):
@pytest.mark.layer_norm
@pytest.mark.native_layer_norm
@pytest.mark.parametrize(
    "shape", [(1, 40999)] if QUICK_MODE else [(1, 40999), (4096, 256)]
    "shape", [(1, 40999)] if QUICK_MODE else [(1, 40999), (4096, 256), (4096, 100)]
)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_layernorm(shape, dtype):

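The added (4096, 100) shape lands in the N <= 128 branch and so exercises the new multi-row persistent kernel. Below is a hedged sketch of the kind of check the test presumably performs; the helper name, device, and tolerances are assumptions, not taken from the repo:

import torch
import torch.nn.functional as F

def check_layernorm(shape, dtype=torch.float16, eps=1e-5):
    # Hypothetical standalone check in the spirit of the accuracy test above;
    # with flag_gems enabled, the low-precision call is served by the Triton kernels.
    x = torch.randn(shape, dtype=dtype, device="cuda")
    w = torch.randn(shape[-1], dtype=dtype, device="cuda")
    b = torch.randn(shape[-1], dtype=dtype, device="cuda")
    ref = F.layer_norm(x.float(), (shape[-1],), w.float(), b.float(), eps).to(dtype)
    out = F.layer_norm(x, (shape[-1],), w, b, eps)
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

check_layernorm((4096, 100))   # exercises the new multi-row persistent path
check_layernorm((1, 40999))    # exercises the loop kernel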