[Operator] optimization of vector_norm, tanh, sigmoid, layernorm (#75)
* [Operator] implement two-pass vector_norm

* [Operator] discard upcasting when gradient is not required

* [Operator] expand config space for layernorm and improve performance

* [bugfix] process non-contiguous tensors for vector_norm
StrongSpoon committed Jun 19, 2024
1 parent fcc56c5 commit 63e0112
Showing 4 changed files with 196 additions and 30 deletions.
13 changes: 9 additions & 4 deletions src/flag_gems/ops/layernorm.py
@@ -9,9 +9,13 @@


def cfggen():
    warps = [1, 2, 4, 8, 16, 32]
    block_m = [1, 2, 4]
    block_n = [1024, 2048, 4096]
    warps = [4, 8, 16]
    configs = [
        triton.Config({"BLOCK_ROW_SIZE": 1, "BLOCK_COL_SIZE": 2048}, num_warps=w)
        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w)
        for m in block_m
        for n in block_n
        for w in warps
    ]
    return configs
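
The autotuning space for layernorm grows from 6 configurations (a fixed 1x2048 tile with six warp counts) to 3 * 3 * 3 = 27 tile/warp combinations. A quick sketch of what the autotuner now enumerates (illustration only, not the repository's code):

import itertools

# 3 row-block sizes x 3 column-block sizes x 3 warp counts = 27 candidates.
block_m = [1, 2, 4]
block_n = [1024, 2048, 4096]
warps = [4, 8, 16]
space = list(itertools.product(block_m, block_n, warps))
assert len(space) == 27
for m, n, w in space[:3]:
    print(f"BLOCK_ROW_SIZE={m}, BLOCK_COL_SIZE={n}, num_warps={w}")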
@@ -203,9 +207,10 @@ class LayerNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True):
        logging.debug("GEMS LAYERNORM FORWARD")
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])
        # dim = x.ndim - len(normalized_shape)
        # M = math.prod(x.shape[:dim])
        N = math.prod(normalized_shape)
        M = x.numel() // N
        x = x.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
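
The forward pass now derives the row count directly from the element count: with N = math.prod(normalized_shape), M = x.numel() // N equals the old math.prod(x.shape[:dim]) whenever normalized_shape covers the trailing dimensions (which layer_norm requires), and it drops the intermediate dim computation. A worked check with hypothetical shapes:

import math
import torch

x = torch.randn(8, 16, 64)
normalized_shape = (64,)

dim = x.ndim - len(normalized_shape)   # old formulation: 3 - 1 = 2
M_old = math.prod(x.shape[:dim])       # 8 * 16 = 128
N = math.prod(normalized_shape)        # 64
M_new = x.numel() // N                 # 8192 // 64 = 128
assert M_old == M_new == 128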
10 changes: 7 additions & 3 deletions src/flag_gems/ops/sigmoid.py
@@ -27,9 +27,13 @@ class Sigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        logging.debug("GEMS SIGMOID FORWARD")
        out = sigmoid_forward(A.to(torch.float32))
        ctx.save_for_backward(out)
        return out.to(A.dtype)
        if A.requires_grad is True:
            out = sigmoid_forward(A.to(torch.float32))
            ctx.save_for_backward(out)
            return out.to(A.dtype)
        else:
            out = sigmoid_forward(A)
            return out

    @staticmethod
    def backward(ctx, out_grad):
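
Sigmoid (and tanh below) now upcasts to float32 only when the input requires a gradient, since the fp32 output is saved solely for the backward pass; inputs that do not need gradients run in their native dtype. A hypothetical usage sketch, assuming flag_gems is installed, a CUDA device is available, and the pointwise kernel preserves the input dtype:

import torch
from flag_gems.ops.sigmoid import Sigmoid  # module path as in this diff

# Training-style input: forward upcasts to fp32, saves the fp32 result for
# backward, and returns the output cast back to fp16.
x_train = torch.randn(1024, dtype=torch.float16, device="cuda", requires_grad=True)
y_train = Sigmoid.apply(x_train)

# Inference-style input: requires_grad is False, so the fp32 round trip is
# skipped and the kernel runs directly on the fp16 tensor.
x_infer = torch.randn(1024, dtype=torch.float16, device="cuda")
y_infer = Sigmoid.apply(x_infer)

assert y_train.dtype == y_infer.dtype == torch.float16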
10 changes: 7 additions & 3 deletions src/flag_gems/ops/tanh.py
@@ -23,9 +23,13 @@ class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        logging.debug("GEMS TANH FORWARD")
        out = tanh_forward(A.to(torch.float32))
        ctx.save_for_backward(out)
        return out.to(A.dtype)
        if A.requires_grad is True:
            out = tanh_forward(A.to(torch.float32))
            ctx.save_for_backward(out)
            return out.to(A.dtype)
        else:
            out = tanh_forward(A)
            return out

    @staticmethod
    def backward(ctx, out_grad):
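
The tanh change mirrors the sigmoid one. Saving the fp32 output is enough for the backward pass because d tanh(x)/dx = 1 - tanh(x)^2 depends only on the output; a plain PyTorch check of that identity (not the repository's Triton backward, which is outside the lines shown here):

import torch

x = torch.randn(64, dtype=torch.float64, requires_grad=True)
y = torch.tanh(x)
y.backward(torch.ones_like(y))

manual = 1.0 - torch.tanh(x.detach()) ** 2   # derivative computed from the output
assert torch.allclose(x.grad, manual, atol=1e-12)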
193 changes: 173 additions & 20 deletions src/flag_gems/ops/vector_norm.py
@@ -1,4 +1,5 @@
import logging
import math

import torch
import triton
@@ -38,6 +39,31 @@ def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    tl.store(Out, out, row_mask)


@libentry()
@triton.jit
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(x * x)
    tl.store(Mid, mid)


@libentry()
@triton.jit
def l2_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sqrt(tl.sum(mid))
    tl.store(Out, out)
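
For full reductions the two-pass scheme replaces the row-wise kernel: pass 1 launches MID_SIZE programs that each reduce one BLOCK_SIZE chunk to a partial sum of squares (Mid), and pass 2 runs a single program that adds the partials and takes the square root (Out). A PyTorch reference of the decomposition, with assumed sizes, not the Triton kernels themselves:

import torch
import torch.nn.functional as F

x = torch.randn(10_000)
BLOCK_SIZE = 128                           # plays the role of the pass-1 chunk
pad = (-x.numel()) % BLOCK_SIZE            # zero-pad so chunks divide evenly
blocks = F.pad(x, (0, pad)).reshape(-1, BLOCK_SIZE)

partial = (blocks * blocks).sum(dim=1)     # pass 1: one partial sum per chunk ("Mid")
result = partial.sum().sqrt()              # pass 2: combine partials, take sqrt ("Out")

assert torch.allclose(result, torch.linalg.vector_norm(x, ord=2), atol=1e-4)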


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
@@ -61,6 +87,31 @@ def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    tl.store(Out, out, row_mask)


@libentry()
@triton.jit
def max_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.max(tl.abs(x))
    tl.store(Mid, mid)


@libentry()
@triton.jit
def max_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.max(mid)
    tl.store(Out, out)


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
@@ -84,6 +135,31 @@ def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    tl.store(Out, out, row_mask)


@libentry()
@triton.jit
def min_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=float("inf")).to(tl.float32)
    mid = tl.min(tl.abs(x))
    tl.store(Mid, mid)


@libentry()
@triton.jit
def min_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=float("inf")).to(tl.float32)
    out = tl.min(mid)
    tl.store(Out, out)


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
@@ -106,6 +182,32 @@ def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    tl.store(Out, out, row_mask)


@libentry()
@triton.jit
def l0_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    cnt = (x != 0).to(tl.float32)
    mid = tl.sum(cnt)
    tl.store(Mid, mid)


@libentry()
@triton.jit
def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.sum(mid)
    tl.store(Out, out)


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
@@ -128,6 +230,31 @@ def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    tl.store(Out, out, row_mask)


@libentry()
@triton.jit
def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    X = X + offset
    Mid = Mid + pid
    mask = offset < M

    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
    mid = tl.sum(tl.math.pow(tl.abs(x), ord))
    tl.store(Mid, mid)


@libentry()
@triton.jit
def l1_norm_kernel_2(Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    Mid = Mid + offset
    mask = offset < MID_SIZE
    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
    out = tl.math.pow(tl.sum(mid), 1 / ord)
    tl.store(Out, out)
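
The same split covers the general p-norm handled by l1_norm_kernel_1/_2 (despite the name, ord is a runtime argument): per-block sums of |x|^p can be added and the p-th root applied once at the end, since ||x||_p = (sum |x_i|^p)^(1/p). A short reference check with an assumed p and size:

import torch

p = 3.0
x = torch.randn(4096)

partial = x.abs().pow(p).reshape(32, 128).sum(dim=1)   # pass 1: per-block sums of |x|^p
result = partial.sum().pow(1.0 / p)                    # pass 2: p-th root of the total

assert torch.allclose(result, torch.linalg.vector_norm(x, ord=p), atol=1e-4)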


def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
    logging.debug("GEMS VECTOR NORM")
    if dtype is not None:
@@ -136,28 +263,54 @@ def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
        dtype = x.dtype
    if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
        raise NotImplementedError(f"vector_norm not implemented for {dtype}")
    if dim is None:

    if dim is None or len(dim) == x.ndim:
        dim = list(range(x.ndim))
    shape = list(x.shape)
    dim = [d % x.ndim for d in dim]
    x = dim_compress(x, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = x.numel() // N
    out = torch.empty(shape, dtype=dtype, device=x.device)
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
    if ord == 2:
        l2_norm_kernel[grid](x, out, M, N)
    elif ord == float("inf"):
        max_norm_kernel[grid](x, out, M, N)
    elif ord == -float("inf"):
        min_norm_kernel[grid](x, out, M, N)
    elif ord == 0:
        l0_norm_kernel[grid](x, out, M, N)
        shape = [1] * x.ndim
        x = dim_compress(x, dim)
        M = x.numel()
        BLOCK_SIZE = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        MID_SIZE = triton.cdiv(M, BLOCK_SIZE)
        BLOCK_MID = triton.next_power_of_2(MID_SIZE)

        mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)
        out = torch.empty(shape, dtype=dtype, device=x.device)
        if ord == 2:
            l2_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
            l2_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
        elif ord == float("inf"):
            max_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
            max_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
        elif ord == -float("inf"):
            min_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
            min_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
        elif ord == 0:
            l0_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
            l0_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
        else:
            l1_norm_kernel_1[(MID_SIZE,)](x, mid, ord, M, BLOCK_SIZE)
            l1_norm_kernel_2[(1,)](mid, out, ord, MID_SIZE, BLOCK_MID)
    else:
        v_norm_kernel[grid](x, out, M, N, ord)
        shape = list(x.shape)
        dim = [d % x.ndim for d in dim]
        x = dim_compress(x, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = x.numel() // N
        out = torch.empty(shape, dtype=dtype, device=x.device)
        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
        if ord == 2:
            l2_norm_kernel[grid](x, out, M, N)
        elif ord == float("inf"):
            max_norm_kernel[grid](x, out, M, N)
        elif ord == -float("inf"):
            min_norm_kernel[grid](x, out, M, N)
        elif ord == 0:
            l0_norm_kernel[grid](x, out, M, N)
        else:
            v_norm_kernel[grid](x, out, M, N, ord)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
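
A hypothetical end-to-end sketch of the wrapper above, assuming flag_gems is installed and a CUDA device is available; it touches the new two-pass full-reduction path, the existing per-dimension path, and a non-contiguous input like the one addressed by the bugfix in this commit:

import torch
from flag_gems.ops.vector_norm import vector_norm  # module path as in this diff

x = torch.randn(64, 128, dtype=torch.float32, device="cuda")

full = vector_norm(x, ord=2)              # dim=None -> two-pass full reduction
rowwise = vector_norm(x, ord=2, dim=[1])  # reduction over dim 1 -> row-wise kernels
noncontig = vector_norm(x.t(), ord=2)     # transposed view is non-contiguous

ref = torch.linalg.vector_norm(x, ord=2)
assert torch.allclose(full, ref, rtol=1e-4)
assert torch.allclose(noncontig, ref, rtol=1e-4)
assert torch.allclose(rowwise, torch.linalg.vector_norm(x, ord=2, dim=1), rtol=1e-4)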
