
work in progress optimize softmax: use better task partitioning #76

Merged 3 commits on Jul 2, 2024

Conversation

@iclementine (Collaborator) commented on Jun 18, 2024

Use different kernels (inner & non_inner) for softmax forward & backward:
1. inner: for reduction along the last dimension (the input is preprocessed to be contiguous)
2. non_inner: for reduction along other dimensions (the input is preprocessed to be contiguous)

Both have a ONE_TILE_PER_CTA static condition:
1. when ONE_TILE_PER_CTA is True, load only one tile per CTA, without looping over the reduction dim;
2. when ONE_TILE_PER_CTA is False, use the online softmax normalizer algorithm to save one sweep over the input.

The non_inner kernels now use a better task partitioning to achieve better coalescing of global memory accesses. For a contiguous tensor with shape (M, N, K) where the reduction is applied along axis 1, the reduction is considered non-inner if K > 1.
In this case, the data tile has shape (TILE_N, TILE_K); when TILE_K is large enough, accesses to global memory are coalesced (see the sketch below).
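
As an illustration of this partitioning, here is a minimal Triton sketch of the ONE_TILE_PER_CTA non-inner case (the kernel name, the flat-offset computation, and the launch grid are illustrative assumptions, not the PR's actual kernel):

import triton
import triton.language as tl


@triton.jit
def softmax_non_inner_sketch(out_ptr, in_ptr, M, N, K,
                             TILE_N: tl.constexpr, TILE_K: tl.constexpr):
    # One CTA handles one (TILE_N, TILE_K) tile of a contiguous (M, N, K) tensor,
    # reducing over axis 1 (size N). Launch grid: (M, triton.cdiv(K, TILE_K)).
    # ONE_TILE_PER_CTA case: TILE_N is a power of two >= N, so no loop over N.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    n_offsets = tl.arange(0, TILE_N)[:, None]           # reduction axis of the tile
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)   # innermost, contiguous axis
    # Element (m, n, k) lives at flat offset m*N*K + n*K + k. Adjacent lanes along
    # k_offsets touch adjacent addresses, so loads coalesce once TILE_K is large enough.
    offsets = pid_m * N * K + n_offsets * K + k_offsets
    mask = (n_offsets < N) & (k_offsets < K)

    inp = tl.load(in_ptr + offsets, mask=mask, other=-float("inf")).to(tl.float32)
    m = tl.max(inp, axis=0)                  # (TILE_K,) max over the reduction axis
    e = tl.exp(inp - m[None, :])
    z = tl.sum(e, axis=0)                    # (TILE_K,) normalizer
    tl.store(out_ptr + offsets, e / z[None, :], mask=mask)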

…rd, both have a ONE_TILE_PER_CTA static condition (to decide whether to load only one tile per CTA).
@iclementine mentioned this pull request on Jun 18, 2024
@tongxin (Contributor) left a comment

LGTM

@iclementine changed the title from "optimize softmax: use better task partitioning" to "work in progress optimize softmax: use better task partitioning" on Jun 19, 2024
for _ in range(0, N, TILE_N):
mask = (n_offsets[:, None] < N) & (k_offsets < K)
input_ptrs = input_ptr + offset
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
@pingzhuu (Collaborator) commented on Jun 19, 2024

This may be a slight improvement in performance.

Suggested change
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"), eviction_policy="evict_last")

Contributor commented:

Did you see any improvement in real tests? I'm just curious. Loading all blocks with evict_last hints is no different from LRU, I guess.

Collaborator commented:

Here are my test results on the 4090 with shape=(64, N), dtype=float16

(benchmark chart attached as an image)

Contributor commented:

Does it also take the evict_first loads into account? They do improve perf.

Collaborator commented:

Using evict_first in the first loop leads to lower performance.

Contributor commented:

No, I wasn't referring to the first loop. The evict_first policy in the second loop seems to have a clearer impact.

for _ in range(0, N, TILE_N):
mask = (n_offsets[:, None] < N) & (k_offsets < K)
input_ptrs = input_ptr + offset
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
Collaborator commented:

Suggested change
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"), eviction_policy="evict_first")

input_ptr,
M,
N,
TILE_M: tl.constexpr,
Collaborator commented:

Is it necessary to tile along the M dimension? I think one block handling one row is enough.

Collaborator (Author) commented:

I think when N is small, tiling along the M dimension is better. I will check the autotuning results to see whether it is required.

Collaborator (Author) commented:

For inner reduction, this is not needed in most cases.

BTW, for non-inner reduction, tiling along the K dimension is needed to achieve higher performance (a host-side sketch of what that partitioning could look like follows).
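
For illustration only, a minimal host-side sketch of tiling along K (softmax_non_inner_sketch, TILE_N, and TILE_K refer to the hypothetical kernel sketched earlier in this conversation, not the PR's actual API):

import torch
import triton

def launch_non_inner_softmax(x: torch.Tensor, M: int, N: int, K: int,
                             TILE_K: int = 32) -> torch.Tensor:
    # Tiling along K gives a 2D launch grid: one axis over the pre-size M and one
    # over K-tiles. Each CTA then reads TILE_K contiguous elements per row of its
    # (TILE_N, TILE_K) tile, which is what makes the global loads coalesce.
    assert x.is_contiguous() and x.numel() == M * N * K
    out = torch.empty_like(x)
    TILE_N = triton.next_power_of_2(N)   # one-tile case; large N needs the looping variant
    grid = (M, triton.cdiv(K, TILE_K))
    softmax_non_inner_sketch[grid](out, x, M, N, K, TILE_N=TILE_N, TILE_K=TILE_K)
    return out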

mask = (m_offsets[:, None] < M) & (n_offsets < N)
input_ptrs = input_ptr + offset
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
m_new = tl.maximum(m, tl.max(inp, 1))
Collaborator commented:

Using tl.max will cause a reduction. We can move it outside the loop. The following code may help:

import triton
import torch
import pytest

import triton.language as tl
from torch import Tensor

import flag_gems


@triton.jit
def _triton_softmax_fwd(
    Y,
    X,
    x_stride_r,
    x_stride_c,
    y_stride_r,
    y_stride_c,
    N,
    BLOCK_SIZE: tl.constexpr,
    ONE_STAGE: tl.constexpr,
):
    pid = tl.program_id(0)
    Y += pid * y_stride_r
    X += pid * x_stride_r

    if ONE_STAGE:
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than N
        mask = tl.arange(0, BLOCK_SIZE) < N
        cols = tl.arange(0, BLOCK_SIZE) * x_stride_c
        row = tl.load(X + cols, mask=mask, other=-float("inf")).to(tl.float32)
        row_minus_max = row - tl.max(row, axis=0)
        row_exp = tl.exp(row_minus_max)
        acc = tl.sum(row_exp, axis=0)
        out = row_exp / acc
        # Write back output to DRAM
        tl.store(Y + cols, out, mask=mask)
    else:
        m_i = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        l_i = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        for k in range(0, N, BLOCK_SIZE):
            cols = tl.arange(0, BLOCK_SIZE) + k
            mask = cols < N
            vals = tl.load(
                X + cols * x_stride_c,
                mask=mask,
                other=float("-inf"),
                eviction_policy="evict_last",
            ).to(tl.float32)
            m_i_new = tl.maximum(m_i, vals)
            l_i_new = l_i * tl.exp(m_i - m_i_new) + tl.exp(vals - m_i_new)

            l_i = l_i_new
            m_i = m_i_new

        m_i_max = tl.max(m_i)
        scale = tl.exp(m_i - m_i_max)
        acc = tl.sum(l_i * scale)

        for k in range(0, N, BLOCK_SIZE):
            cols = tl.arange(0, BLOCK_SIZE) + k
            mask = cols < N
            vals = tl.load(
                X + cols * x_stride_c,
                mask=mask,
                other=float("-inf"),
                eviction_policy="evict_first",
            ).to(tl.float32)
            out = tl.exp(vals - m_i_max) / acc
            tl.store(Y + cols * y_stride_c, out, mask=mask)


def softmax(input: Tensor, dim: int = -1, dtype=None) -> Tensor:
    need_transpose = dim != -1 and dim != (len(input.shape) - 1)
    if need_transpose:
        input = input.transpose(-1, dim)
    input_shape = input.shape
    input = input.view(-1, input.shape[-1])
    output = torch.empty_like(input, dtype=dtype)
    M, N = input.shape

    ONE_STAGE = N <= 32768
    BLOCK_SIZE = triton.next_power_of_2(N) if ONE_STAGE else 4096
    num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 4096 else 16)

    # Enqueue the kernel. The 1D launch grid is simple: one kernel instance
    # per row of the input matrix.
    _triton_softmax_fwd[(M,)](
        output,
        input,
        input.stride(0),
        input.stride(1),
        output.stride(0),
        output.stride(1),
        N,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
        ONE_STAGE=ONE_STAGE,
    )
    output = output.view(input_shape)
    if need_transpose:
        output = output.transpose(-1, dim)
    return output


implementations = {
    "Torch": torch.nn.functional.softmax,
    "SiliconFlow": softmax,
    "Flaggems": flag_gems.softmax,
}
line_names = list(implementations.keys())
styles = [
    ("blue", "-"),
    ("green", "-"),
    ("red", "-"),
]


@pytest.mark.parametrize("x_name", ["M"])
@pytest.mark.parametrize("x_vals", [[128 * i for i in range(1, 41, 8)]])
@pytest.mark.parametrize("other_dim", [4096, 16384, 65536, 100000])
@pytest.mark.parametrize("dtype", [torch.float16], ids=lambda x: f"{x}")
def test_benchmark_softmax(x_name, x_vals, other_dim, dtype):
    if x_name == "M":
        plot_name = f"softmax_M_{other_dim}_{dtype}"
        args = {"N": other_dim, "dtype": dtype}
    else:
        plot_name = f"softmax_{other_dim}_N_{dtype}"
        args = {"M": other_dim, "dtype": dtype}

    import os

    os.makedirs("results", exist_ok=True)
    torch.cuda.set_stream(torch.cuda.Stream())

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=[x_name],
            x_vals=x_vals,
            line_arg="provider",
            line_vals=line_names,
            line_names=line_names,
            styles=styles,
            ylabel="GB/s",
            plot_name=plot_name,
            args=args,
        )
    )
    def benchmark(M, N, dtype, provider):
        x = torch.randn(M, N, device="cuda", dtype=dtype)

        functor = implementations[provider]
        ms, min_ms, max_ms = triton.testing.do_bench(
            (lambda: functor(x, dim=-1)), warmup=5, rep=10, quantiles=[0.5, 0.2, 0.8]
        )

        gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms), gbps(min_ms), gbps(max_ms)

    benchmark.run(show_plots=True, print_data=True, save_path="results")

(benchmark plots attached as images)

@tongxin (Contributor) commented on Jun 19, 2024

Good catch. tl.max was not supposed to be included in the loop; it generates redundant warp reductions.

Collaborator (Author) commented:

I think we can use a better tile-size search space; as it stands in this PR, the tile size is 1024, which is too small (a sketch of a broader search space follows).

As for the algorithmic modification, I will test it. Thanks.
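
For illustration only, a broader search space might look something like this (the config values and the kernel below are a sketch under the one-tile assumption TILE_N >= N, not the PR's actual settings):

import triton
import triton.language as tl

# Illustrative tile sizes beyond 1024, combined with a few warp counts; the
# autotuner picks per reduction size N. Configs with TILE_N < N would require
# the looping (online-normalizer) variant instead of this one-tile kernel.
configs = [
    triton.Config({"TILE_N": tile_n}, num_warps=warps)
    for tile_n in (1024, 2048, 4096, 8192)
    for warps in (4, 8, 16)
]

@triton.autotune(configs=configs, key=["N"])
@triton.jit
def softmax_inner_one_tile(out_ptr, in_ptr, M, N, TILE_N: tl.constexpr):
    # One CTA per row of a contiguous (M, N) input; assumes TILE_N >= N.
    pid = tl.program_id(0)
    offs = tl.arange(0, TILE_N)
    mask = offs < N
    row = tl.load(in_ptr + pid * N + offs, mask=mask, other=-float("inf")).to(tl.float32)
    e = tl.exp(row - tl.max(row, axis=0))
    tl.store(out_ptr + pid * N + offs, e / tl.sum(e, axis=0), mask=mask)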

Collaborator (Author) commented:

I tested the code. It achieves good performance in most cases for inner and outer reduction. Nice job.

But a contiguous tensor of shape (m, n, k), transposed into (m, k, n), cannot be viewed as (m*k, n) without copying the data. I replaced the view with a reshape, but the performance is not good in some cases for middle reduction (a small example of the view/reshape point is below).
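
A small torch example of the view/reshape point, with illustrative shapes:

import torch

x = torch.randn(2, 3, 4)       # contiguous, shape (m, n, k)
y = x.transpose(-1, 1)         # shape (2, 4, 3), no longer contiguous
# y.view(2 * 4, 3)             # raises: view requires compatible strides
z = y.reshape(2 * 4, 3)        # works, but has to copy the data here
print(z.data_ptr() == y.data_ptr())   # False: reshape made a copy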

…tioning(considering TILE_SIZE and number of blocks), use better eviction policy, better algorithm for softmax online normalizer, and reverse the second loop, specialize the last iteration, etc.
@iclementine (Collaborator, Author) commented:

Performance

Throughput (GB/s) of the forward function on an RTX 3090; softmax over tensors of dtype float32.

inner

inner_reduction, pre_size: 10:
   reduction_size        aten          sf   flag_gems
0         32768.0   98.461540  196.923079  196.923079
1         43690.0  103.432764  126.417824  148.403529
2         54613.0  106.666018  133.724293  158.023727
3         65536.0  113.777776  196.923079  215.013129
4         76458.0  109.038793  145.689789  165.924473
5         87381.0  106.666260  151.703123  166.630438
6         98304.0  108.169011  207.567559  232.727271
7        109226.0  106.666018  152.380020  174.148603
8        120149.0  104.296005  148.994297  177.106428
9        131072.0  108.936167  213.333326  238.139535
inner_reduction, pre_size: 128:
   reduction_size        aten          sf   flag_gems
0         32768.0  394.795186  655.360017  655.360017
1         43690.0  407.960313  490.898871  520.119052
2         54613.0  401.566169  455.108345  496.481814
3         65536.0  407.055892  555.389814  560.136737
4         76458.0  406.691478  483.911407  503.013180
5         87381.0  402.677432  464.792540  490.904489
6         98304.0  401.240812  540.131858  549.184373
7        109226.0  401.566169  464.791490  487.616065
8        120149.0  404.542099  458.583993  488.410548
9        131072.0  413.476343  551.012070  557.753193
inner_reduction, pre_size: 1024:
   reduction_size        aten          sf   flag_gems
0         32768.0  418.092503  809.086437  809.242505
1         43690.0  418.086114  507.285916  511.742298
2         54613.0  409.853671  487.684117  490.352427
3         65536.0  416.483796  567.411260  569.878248
4         76458.0  411.340947  476.745132  486.219383
5         87381.0  409.278696  469.790311  477.491812
6         98304.0  416.763105  563.346690  562.138694
7        109226.0  406.990209  458.451216  465.286480
8        120149.0  404.920203  452.112911  461.224560
9        131072.0  418.760388  565.574964  565.880187
inner_reduction, pre_size: 4096:
   reduction_size        aten          sf   flag_gems
0         32768.0  431.335273  830.884346  831.502037
1         43690.0  431.373033  508.577669  509.318771
2         54613.0  424.384637  493.258824  490.214880
3         65536.0  429.392307  563.902137  566.644689
4         76458.0  421.401311  477.769175  481.909811
5         87381.0  419.282064  468.295423  473.048886
6         98304.0  428.631707  562.943433  563.750526
7        109226.0  417.341124  459.415341  462.087793
8        120149.0  415.158315  446.908547  457.601514
9        131072.0  427.990199  562.163802  563.750550

outer

outer_reduction, post_size: 10:
   reduction_size      aten         sf  flag_gems
0            32.0  0.125000   0.096154   0.416667
1          3669.0  4.623236  11.024640  25.550139
2          7306.0  4.877169  21.953125  33.575367
3         10944.0  5.000000  30.535714  29.482758
4         14581.0  5.040445  39.280710  32.546874
5         18218.0  5.065058  37.454771  32.347301
6         21856.0  5.097015  38.806818  34.846940
7         25493.0  5.106771  40.645728  34.338631
8         29130.0  5.125633  41.377841  36.123512
9         32768.0  5.140562  41.967214  37.101448
outer_reduction, post_size: 128:
   reduction_size      aten          sf   flag_gems
0            32.0  1.523810    1.183815    5.333333
1          3669.0  1.206908  107.911763  183.163809
2          7306.0  1.100964  114.156250  173.952382
3         10944.0  0.993194  113.999996  185.197254
4         14581.0  0.887624  115.722223  189.363644
5         18218.0  0.889768  112.391745  193.808505
6         21856.0  0.889866  115.640213  196.900899
7         25493.0  0.886772  117.479266  197.955838
8         29130.0  0.887596  118.414629  199.520540
9         32768.0  0.887806  119.591245  201.030670
outer_reduction, post_size: 1024:
   reduction_size      aten          sf   flag_gems
0            32.0  8.000000    9.481481   37.406394
1          3669.0  6.925908  142.485439  489.200012
2          7306.0  6.930867  140.500002  495.322019
3         10944.0  6.932067  137.014090  478.426239
4         14581.0  6.933429  138.044975  468.465855
5         18218.0  6.933587  133.587529  452.621108
6         21856.0  6.936211  135.646234  449.480727
7         25493.0  6.933805  136.783372  437.648071
8         29130.0  6.938392  137.730500  428.382345
9         32768.0  6.937779  138.847453  422.132036
outer_reduction, post_size: 4096:
   reduction_size       aten          sf   flag_gems
0            32.0  32.000000   36.571428  113.777774
1          3669.0  27.297838  151.494189  546.083739
2          7306.0  27.344092  148.063328  552.699776
3         10944.0  27.360001  142.709043  555.885693
4         14581.0  27.364495  145.174859  558.124388
5         18218.0  27.359490  140.572474  557.871748
6         21856.0  27.347774  141.348425  558.619814
7         25493.0  27.364932  142.618187  558.750663
8         29130.0  27.367371  143.608076  559.519809
9         32768.0  27.359391  144.671086  557.484446

middle

middle_reduction, pre_size: 128, post_size: 128:
   reduction_size        aten   flag_gems
0            32.0  180.147444   48.427377
1           256.0  105.433551  404.370465
2           480.0  106.202331  509.763814
3           704.0  106.548959  527.197167
4           928.0  106.548411  533.413104
5          1152.0  106.776976  530.891537
6          1376.0  106.958652  539.371828
7          1600.0  107.084139  538.164343
8          1824.0  106.987697  540.660002
9          2048.0  107.035593  543.062987
middle_reduction, pre_size: 128, post_size: 1024:
   reduction_size        aten   flag_gems
0            32.0  336.418763  347.651130
1           256.0  329.880774  548.142167
2           480.0  326.672338  544.096432
3           704.0  329.197793  548.941028
4           928.0  326.083256  547.280923
5          1152.0  329.702632  549.114227
6          1376.0  329.783541  549.821045
7          1600.0  329.344425  550.046866
8          1824.0  327.507286  548.083169
9          2048.0  323.959537  550.125228
middle_reduction, pre_size: 1024, post_size: 128:
   reduction_size        aten   flag_gems
0            32.0  333.350601  330.436084
1           256.0  303.488434  555.473694
2           480.0  284.351859  558.454913
3           704.0  288.704236  556.400912
4           928.0  290.501996  560.200720
5          1152.0  276.786006  558.245433
6          1376.0  287.837470  561.228678
7          1600.0  291.423306  559.765604
8          1824.0  284.014062  558.429457
9          2048.0  289.978973  560.532656

@iclementine mentioned this pull request on Jul 1, 2024
@pingzhuu (Collaborator) left a comment

LGTM

@tongxin merged commit 36ee6e3 into master on Jul 2, 2024 (3 checks passed).
@iclementine deleted the opt-softmax branch on July 15, 2024.