diff --git a/src/flag_gems/ops/layernorm.py b/src/flag_gems/ops/layernorm.py
index a9cb6291..87d22110 100644
--- a/src/flag_gems/ops/layernorm.py
+++ b/src/flag_gems/ops/layernorm.py
@@ -9,9 +9,13 @@
 
 
 def cfggen():
-    warps = [1, 2, 4, 8, 16, 32]
+    block_m = [1, 2, 4]
+    block_n = [1024, 2048, 4096]
+    warps = [4, 8, 16]
     configs = [
-        triton.Config({"BLOCK_ROW_SIZE": 1, "BLOCK_COL_SIZE": 2048}, num_warps=w)
+        triton.Config({"BLOCK_ROW_SIZE": m, "BLOCK_COL_SIZE": n}, num_warps=w)
+        for m in block_m
+        for n in block_n
         for w in warps
     ]
     return configs
@@ -203,9 +207,10 @@ class LayerNorm(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True):
         logging.debug("GEMS LAYERNORM FORWARD")
-        dim = x.ndim - len(normalized_shape)
-        M = math.prod(x.shape[:dim])
+        # dim = x.ndim - len(normalized_shape)
+        # M = math.prod(x.shape[:dim])
         N = math.prod(normalized_shape)
+        M = x.numel() // N
         x = x.contiguous()
         weight = weight.contiguous()
         bias = bias.contiguous()
diff --git a/src/flag_gems/ops/sigmoid.py b/src/flag_gems/ops/sigmoid.py
index 833e9559..908b5226 100644
--- a/src/flag_gems/ops/sigmoid.py
+++ b/src/flag_gems/ops/sigmoid.py
@@ -27,9 +27,13 @@ class Sigmoid(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A):
         logging.debug("GEMS SIGMOID FORWARD")
-        out = sigmoid_forward(A.to(torch.float32))
-        ctx.save_for_backward(out)
-        return out.to(A.dtype)
+        if A.requires_grad is True:
+            out = sigmoid_forward(A.to(torch.float32))
+            ctx.save_for_backward(out)
+            return out.to(A.dtype)
+        else:
+            out = sigmoid_forward(A)
+            return out
 
     @staticmethod
     def backward(ctx, out_grad):
diff --git a/src/flag_gems/ops/tanh.py b/src/flag_gems/ops/tanh.py
index b2cff858..448de6a0 100644
--- a/src/flag_gems/ops/tanh.py
+++ b/src/flag_gems/ops/tanh.py
@@ -23,9 +23,13 @@ class Tanh(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A):
         logging.debug("GEMS TANH FORWARD")
-        out = tanh_forward(A.to(torch.float32))
-        ctx.save_for_backward(out)
-        return out.to(A.dtype)
+        if A.requires_grad is True:
+            out = tanh_forward(A.to(torch.float32))
+            ctx.save_for_backward(out)
+            return out.to(A.dtype)
+        else:
+            out = tanh_forward(A)
+            return out
 
     @staticmethod
     def backward(ctx, out_grad):
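The Sigmoid and Tanh forwards above now branch on `A.requires_grad`, so inference calls skip both the float32 upcast and `ctx.save_for_backward`. The following is a minimal, self-contained sketch of the same gating pattern using plain `torch.sigmoid` in place of the Triton kernels; it is illustrative only and is not the FlagGems implementation.

```python
import torch


class SigmoidLike(torch.autograd.Function):
    """Illustrative stand-in for the gated forward above: plain torch ops
    replace the Triton sigmoid_forward kernel."""

    @staticmethod
    def forward(ctx, a):
        if a.requires_grad:
            # Training path: upcast to float32 and keep the activation
            # around for the backward pass.
            out = torch.sigmoid(a.to(torch.float32))
            ctx.save_for_backward(out)
            return out.to(a.dtype)
        # Inference path: no upcast, nothing saved.
        return torch.sigmoid(a)

    @staticmethod
    def backward(ctx, grad_out):
        (out,) = ctx.saved_tensors
        # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
        grad = grad_out.to(torch.float32) * out * (1.0 - out)
        return grad.to(grad_out.dtype)


x = torch.randn(8, requires_grad=True)
SigmoidLike.apply(x).sum().backward()      # saves the activation

with torch.no_grad():
    SigmoidLike.apply(torch.randn(8))      # takes the lightweight branch
```

Note that the gate keys off the tensor's `requires_grad` attribute, not the current autograd mode, which mirrors the behavior of the patch.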
diff --git a/src/flag_gems/ops/vector_norm.py b/src/flag_gems/ops/vector_norm.py
index c7c26524..d9861211 100644
--- a/src/flag_gems/ops/vector_norm.py
+++ b/src/flag_gems/ops/vector_norm.py
@@ -1,4 +1,5 @@
 import logging
+import math
 
 import torch
 import triton
@@ -38,6 +39,31 @@ def l2_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     tl.store(Out, out, row_mask)
 
 
+@libentry()
+@triton.jit
+def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    X = X + offset
+    Mid = Mid + pid
+    mask = offset < M
+
+    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
+    mid = tl.sum(x * x)
+    tl.store(Mid, mid)
+
+
+@libentry()
+@triton.jit
+def l2_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
+    offset = tl.arange(0, BLOCK_MID)
+    Mid = Mid + offset
+    mask = offset < MID_SIZE
+    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
+    out = tl.sqrt(tl.sum(mid))
+    tl.store(Out, out)
+
+
 @libentry()
 @triton.autotune(configs=cfggen(), key=["M", "N"])
 @triton.jit
@@ -61,6 +87,31 @@ def max_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     tl.store(Out, out, row_mask)
 
 
+@libentry()
+@triton.jit
+def max_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    X = X + offset
+    Mid = Mid + pid
+    mask = offset < M
+
+    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
+    mid = tl.max(tl.abs(x))
+    tl.store(Mid, mid)
+
+
+@libentry()
+@triton.jit
+def max_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
+    offset = tl.arange(0, BLOCK_MID)
+    Mid = Mid + offset
+    mask = offset < MID_SIZE
+    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
+    out = tl.max(mid)
+    tl.store(Out, out)
+
+
 @libentry()
 @triton.autotune(configs=cfggen(), key=["M", "N"])
 @triton.jit
@@ -84,6 +135,31 @@ def min_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     tl.store(Out, out, row_mask)
 
 
+@libentry()
+@triton.jit
+def min_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    X = X + offset
+    Mid = Mid + pid
+    mask = offset < M
+
+    x = tl.load(X, mask=mask, other=float("inf")).to(tl.float32)
+    mid = tl.min(tl.abs(x))
+    tl.store(Mid, mid)
+
+
+@libentry()
+@triton.jit
+def min_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
+    offset = tl.arange(0, BLOCK_MID)
+    Mid = Mid + offset
+    mask = offset < MID_SIZE
+    mid = tl.load(Mid, mask=mask, other=float("inf")).to(tl.float32)
+    out = tl.min(mid)
+    tl.store(Out, out)
+
+
 @libentry()
 @triton.autotune(configs=cfggen(), key=["M", "N"])
 @triton.jit
@@ -106,6 +182,32 @@ def l0_norm_kernel(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     tl.store(Out, out, row_mask)
 
 
+@libentry()
+@triton.jit
+def l0_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    X = X + offset
+    Mid = Mid + pid
+    mask = offset < M
+
+    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
+    cnt = (x != 0).to(tl.float32)
+    mid = tl.sum(cnt)
+    tl.store(Mid, mid)
+
+
+@libentry()
+@triton.jit
+def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
+    offset = tl.arange(0, BLOCK_MID)
+    Mid = Mid + offset
+    mask = offset < MID_SIZE
+    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
+    out = tl.sum(mid)
+    tl.store(Out, out)
+
+
 @libentry()
 @triton.autotune(configs=cfggen(), key=["M", "N"])
 @triton.jit
@@ -128,6 +230,31 @@ def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexp
     tl.store(Out, out, row_mask)
 
 
+@libentry()
+@triton.jit
+def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    X = X + offset
+    Mid = Mid + pid
+    mask = offset < M
+
+    x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
+    mid = tl.sum(tl.math.pow(tl.abs(x), ord))
+    tl.store(Mid, mid)
+
+
+@libentry()
+@triton.jit
+def l1_norm_kernel_2(Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr):
+    offset = tl.arange(0, BLOCK_MID)
+    Mid = Mid + offset
+    mask = offset < MID_SIZE
+    mid = tl.load(Mid, mask=mask, other=0.0).to(tl.float32)
+    out = tl.math.pow(tl.sum(mid), 1 / ord)
+    tl.store(Out, out)
+
+
 def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
     logging.debug("GEMS VECTOR NORM")
     if dtype is not None:
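The wrapper hunk that follows dispatches to these `*_kernel_1` / `*_kernel_2` pairs when the reduction covers every dimension: stage 1 reduces each `BLOCK_SIZE` chunk to one partial value in `Mid`, and stage 2 reduces the `MID_SIZE` partials to the final scalar. Below is a rough pure-PyTorch rehearsal of that two-stage scheme for the L2 case; it is an illustration of the sizing and the math, not the FlagGems code path.

```python
import math

import torch

x = torch.randn(10_000)
M = x.numel()
BLOCK_SIZE = 1 << (math.ceil(math.sqrt(M)) - 1).bit_length()  # ~ triton.next_power_of_2(ceil(sqrt(M)))
MID_SIZE = math.ceil(M / BLOCK_SIZE)                          # ~ triton.cdiv(M, BLOCK_SIZE)

# Stage 1 (cf. l2_norm_kernel_1): one partial sum of squares per block,
# with out-of-range elements masked to 0 via padding.
padded = torch.nn.functional.pad(x, (0, MID_SIZE * BLOCK_SIZE - M))
mid = (padded.view(MID_SIZE, BLOCK_SIZE).float() ** 2).sum(dim=1)

# Stage 2 (cf. l2_norm_kernel_2): a single program sums the partials and takes sqrt.
out = mid.sum().sqrt()
assert torch.allclose(out, torch.linalg.vector_norm(x))
```

Choosing `BLOCK_SIZE` near `sqrt(M)` keeps the two stages balanced: each launches on the order of `sqrt(M)` work items.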
@@ -136,28 +263,54 @@ def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None):
         dtype = x.dtype
     if dtype not in [torch.float16, torch.float32, torch.bfloat16]:
         raise NotImplementedError(f"vector_norm not implemented for {dtype}")
-    if dim is None:
+
+    if dim is None or len(dim) == x.ndim:
         dim = list(range(x.ndim))
-    shape = list(x.shape)
-    dim = [d % x.ndim for d in dim]
-    x = dim_compress(x, dim)
-    N = 1
-    for i in dim:
-        N *= shape[i]
-        shape[i] = 1
-    M = x.numel() // N
-    out = torch.empty(shape, dtype=dtype, device=x.device)
-    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
-    if ord == 2:
-        l2_norm_kernel[grid](x, out, M, N)
-    elif ord == float("inf"):
-        max_norm_kernel[grid](x, out, M, N)
-    elif ord == -float("inf"):
-        min_norm_kernel[grid](x, out, M, N)
-    elif ord == 0:
-        l0_norm_kernel[grid](x, out, M, N)
+        shape = [1] * x.ndim
+        x = dim_compress(x, dim)
+        M = x.numel()
+        BLOCK_SIZE = triton.next_power_of_2(math.ceil(math.sqrt(M)))
+        MID_SIZE = triton.cdiv(M, BLOCK_SIZE)
+        BLOCK_MID = triton.next_power_of_2(MID_SIZE)
+
+        mid = torch.empty([MID_SIZE], dtype=dtype, device=x.device)
+        out = torch.empty(shape, dtype=dtype, device=x.device)
+        if ord == 2:
+            l2_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
+            l2_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
+        elif ord == float("inf"):
+            max_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
+            max_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
+        elif ord == -float("inf"):
+            min_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
+            min_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
+        elif ord == 0:
+            l0_norm_kernel_1[(MID_SIZE,)](x, mid, M, BLOCK_SIZE)
+            l0_norm_kernel_2[(1,)](mid, out, MID_SIZE, BLOCK_MID)
+        else:
+            l1_norm_kernel_1[(MID_SIZE,)](x, mid, ord, M, BLOCK_SIZE)
+            l1_norm_kernel_2[(1,)](mid, out, ord, MID_SIZE, BLOCK_MID)
     else:
-        v_norm_kernel[grid](x, out, M, N, ord)
+        shape = list(x.shape)
+        dim = [d % x.ndim for d in dim]
+        x = dim_compress(x, dim)
+        N = 1
+        for i in dim:
+            N *= shape[i]
+            shape[i] = 1
+        M = x.numel() // N
+        out = torch.empty(shape, dtype=dtype, device=x.device)
+        grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
+        if ord == 2:
+            l2_norm_kernel[grid](x, out, M, N)
+        elif ord == float("inf"):
+            max_norm_kernel[grid](x, out, M, N)
+        elif ord == -float("inf"):
+            min_norm_kernel[grid](x, out, M, N)
+        elif ord == 0:
+            l0_norm_kernel[grid](x, out, M, N)
+        else:
+            v_norm_kernel[grid](x, out, M, N, ord)
     if not keepdim:
         out = out.squeeze(dim=dim)
     return out
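A quick way to exercise both branches of the rewritten `vector_norm` wrapper is to compare it against `torch.linalg.vector_norm`. The sketch below is a sanity check, not part of the patch; it assumes a CUDA device and that `src/flag_gems` is importable as `flag_gems`.

```python
import torch

from flag_gems.ops.vector_norm import vector_norm

x = torch.randn(3, 5, 7, device="cuda", dtype=torch.float32)

# Full reduction (dim=None) now takes the new two-stage kernels.
for ord in (2, float("inf"), -float("inf"), 0, 1, 3.0):
    ref = torch.linalg.vector_norm(x, ord=ord)
    got = vector_norm(x, ord=ord)
    torch.testing.assert_close(got, ref, rtol=1e-4, atol=1e-4)

# Reducing over a strict subset of dims still uses the original row-wise kernels.
ref = torch.linalg.vector_norm(x, ord=2, dim=(1, 2), keepdim=True)
got = vector_norm(x, ord=2, dim=[1, 2], keepdim=True)
torch.testing.assert_close(got, ref, rtol=1e-4, atol=1e-4)
```

Passing `dim` as a list covering every dimension also routes through the full-reduction branch, since the wrapper checks `len(dim) == x.ndim`.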