work in progress optimize softmax: use better task partitioning #76
Conversation
…rd, both have a ONE_TILE_PER_CTA static condition (to decide whether to load only one tile per cta).
LGTM
src/flag_gems/ops/softmax.py
Outdated
for _ in range(0, N, TILE_N):
    mask = (n_offsets[:, None] < N) & (k_offsets < K)
    input_ptrs = input_ptr + offset
    inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
This may be a slight improvement in performance.
Suggested change:
-    inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
+    inp = tl.load(input_ptrs, mask=mask, other=-float("inf"), eviction_policy="evict_last")
Did you see any improvement in real tests? I'm just curious. Loading all blocks with evict_last hints is no different from LRU, I guess.
Does it take into account the evict_first loads also? They do improve perf.
Using evict_first in the first loop leads to lower performance.
No, I wasn't referring to the first loop. The evict_first policy in the second loop seems to have a clearer impact.
src/flag_gems/ops/softmax.py
Outdated
for _ in range(0, N, TILE_N):
    mask = (n_offsets[:, None] < N) & (k_offsets < K)
    input_ptrs = input_ptr + offset
    inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
Suggested change:
-    inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
+    inp = tl.load(input_ptrs, mask=mask, other=-float("inf"), eviction_policy="evict_first")
src/flag_gems/ops/softmax.py
Outdated
    input_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
Is it necessary to tile along the M dimension? I think one block handling one row is enough.
I think when N is small, tiling along the M dimension is better. I will check the autotuning results to see whether it is required.
For inner reduction, this is not needed in most cases.
BTW, for non-inner reduction, tiling along the K dimension is needed to achieve higher performance.
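As a rough sketch of the kind of kernel this implies (the kernel name, signature, and grid layout below are assumptions for illustration, not the PR's actual code):

import triton
import triton.language as tl

@triton.jit
def softmax_non_inner_sketch(  # hypothetical kernel, for illustration only
    out_ptr, in_ptr, M, N, K,
    TILE_N: tl.constexpr, TILE_K: tl.constexpr,
):
    # grid = (M, cdiv(K, TILE_K)); one program reduces axis N for one k-tile
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)

    # pass 1: online normalizer (running max and running sum) per k column
    m_acc = tl.full((TILE_K,), value=float("-inf"), dtype=tl.float32)
    l_acc = tl.zeros((TILE_K,), dtype=tl.float32)
    for start_n in range(0, N, TILE_N):
        n_offsets = start_n + tl.arange(0, TILE_N)
        # (TILE_N, TILE_K) tile: consecutive k offsets map to consecutive
        # addresses, so a large TILE_K keeps the global loads coalesced
        offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets[None, :]
        mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
        inp = tl.load(in_ptr + offsets, mask=mask, other=float("-inf"))
        m_new = tl.maximum(m_acc, tl.max(inp, 0))
        l_acc = l_acc * tl.exp(m_acc - m_new) + tl.sum(tl.exp(inp - m_new[None, :]), 0)
        m_acc = m_new

    # pass 2: normalize and store
    for start_n in range(0, N, TILE_N):
        n_offsets = start_n + tl.arange(0, TILE_N)
        offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets[None, :]
        mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
        inp = tl.load(in_ptr + offsets, mask=mask, other=float("-inf"))
        out = tl.exp(inp - m_acc[None, :]) / l_acc[None, :]
        tl.store(out_ptr + offsets, out, mask=mask)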
src/flag_gems/ops/softmax.py
Outdated
    mask = (m_offsets[:, None] < M) & (n_offsets < N)
    input_ptrs = input_ptr + offset
    inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
    m_new = tl.maximum(m, tl.max(inp, 1))
Using tl.max will cause a reduction. We can move it outside the loop. The following code may help:
import triton
import torch
import pytest
import triton.language as tl
from torch import Tensor

import flag_gems


@triton.jit
def _triton_softmax_fwd(
    Y,
    X,
    x_stride_r,
    x_stride_c,
    y_stride_r,
    y_stride_c,
    N,
    BLOCK_SIZE: tl.constexpr,
    ONE_STAGE: tl.constexpr,
):
    pid = tl.program_id(0)
    Y += pid * y_stride_r
    X += pid * x_stride_r
    if ONE_STAGE:
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than N
        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        row = tl.load(X + cols * x_stride_c, mask=mask, other=-float("inf")).to(tl.float32)
        row_minus_max = row - tl.max(row, axis=0)
        row_exp = tl.exp(row_minus_max)
        acc = tl.sum(row_exp, axis=0)
        out = row_exp / acc
        # Write back output to DRAM
        tl.store(Y + cols * y_stride_c, out, mask=mask)
    else:
        m_i = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        l_i = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        for k in range(0, N, BLOCK_SIZE):
            cols = tl.arange(0, BLOCK_SIZE) + k
            mask = cols < N
            vals = tl.load(
                X + cols * x_stride_c,
                mask=mask,
                other=float("-inf"),
                eviction_policy="evict_last",
            ).to(tl.float32)
            m_i_new = tl.maximum(m_i, vals)
            l_i_new = l_i * tl.exp(m_i - m_i_new) + tl.exp(vals - m_i_new)
            l_i = l_i_new
            m_i = m_i_new
        m_i_max = tl.max(m_i)
        scale = tl.exp(m_i - m_i_max)
        acc = tl.sum(l_i * scale)
        for k in range(0, N, BLOCK_SIZE):
            cols = tl.arange(0, BLOCK_SIZE) + k
            mask = cols < N
            vals = tl.load(
                X + cols * x_stride_c,
                mask=mask,
                other=float("-inf"),
                eviction_policy="evict_first",
            ).to(tl.float32)
            out = tl.exp(vals - m_i_max) / acc
            tl.store(Y + cols * y_stride_c, out, mask=mask)


def softmax(input: Tensor, dim: int = -1, dtype=None) -> Tensor:
    need_transpose = dim != -1 and dim != (len(input.shape) - 1)
    if need_transpose:
        input = input.transpose(-1, dim)
    input_shape = input.shape
    input = input.view(-1, input.shape[-1])
    output = torch.empty_like(input, dtype=dtype)
    M, N = input.shape
    ONE_STAGE = N <= 32768
    BLOCK_SIZE = triton.next_power_of_2(N) if ONE_STAGE else 4096
    num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 4096 else 16)
    # Enqueue kernel. The 1D launch grid is simple: one kernel instance
    # per row of the input matrix.
    _triton_softmax_fwd[(M,)](
        output,
        input,
        input.stride(0),
        input.stride(1),
        output.stride(0),
        output.stride(1),
        N,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
        ONE_STAGE=ONE_STAGE,
    )
    output = output.view(input_shape)
    if need_transpose:
        output = output.transpose(-1, dim)
    return output


implementations = {
    "Torch": torch.nn.functional.softmax,
    "SiliconFlow": softmax,
    "Flaggems": flag_gems.softmax,
}
line_names = list(implementations.keys())
styles = [
    ("blue", "-"),
    ("green", "-"),
    ("red", "-"),
]


@pytest.mark.parametrize("x_name", ["M"])
@pytest.mark.parametrize("x_vals", [[128 * i for i in range(1, 41, 8)]])
@pytest.mark.parametrize("other_dim", [4096, 16384, 65536, 100000])
@pytest.mark.parametrize("dtype", [torch.float16], ids=lambda x: f"{x}")
def test_benchmark_softmax(x_name, x_vals, other_dim, dtype):
    if x_name == "M":
        plot_name = f"softmax_M_{other_dim}_{dtype}"
        args = {"N": other_dim, "dtype": dtype}
    else:
        plot_name = f"softmax_{other_dim}_N_{dtype}"
        args = {"M": other_dim, "dtype": dtype}

    import os

    os.makedirs("results", exist_ok=True)
    torch.cuda.set_stream(torch.cuda.Stream())

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=[x_name],
            x_vals=x_vals,
            line_arg="provider",
            line_vals=line_names,
            line_names=line_names,
            styles=styles,
            ylabel="GB/s",
            plot_name=plot_name,
            args=args,
        )
    )
    def benchmark(M, N, dtype, provider):
        x = torch.randn(M, N, device="cuda", dtype=dtype)
        functor = implementations[provider]
        ms, min_ms, max_ms = triton.testing.do_bench(
            (lambda: functor(x, dim=-1)), warmup=5, rep=10, quantiles=[0.5, 0.2, 0.8]
        )
        gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms), gbps(min_ms), gbps(max_ms)

    benchmark.run(show_plots=True, print_data=True, save_path="results")
Good catch. tl.max was not supposed to be included in the loop; it generates redundant warp reductions.
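To make the contrast concrete, a toy sketch (hypothetical kernel, not the PR's code) that computes a row max both ways:

import triton
import triton.language as tl

@triton.jit
def running_max_demo(Out, X, N, BLOCK_SIZE: tl.constexpr):
    # Costly pattern: tl.max performs a cross-lane (warp) reduction on
    # every loop iteration.
    m_scalar = -float("inf")
    # Cheap pattern: keep a per-lane running max with elementwise
    # tl.maximum, and reduce only once after the loop.
    m_vec = tl.full((BLOCK_SIZE,), value=-float("inf"), dtype=tl.float32)
    for k in range(0, N, BLOCK_SIZE):
        cols = tl.arange(0, BLOCK_SIZE) + k
        vals = tl.load(X + cols, mask=cols < N, other=-float("inf"))
        m_scalar = tl.maximum(m_scalar, tl.max(vals, 0))  # reduction per iteration
        m_vec = tl.maximum(m_vec, vals)                   # elementwise only
    m_once = tl.max(m_vec, 0)  # single reduction; equal to m_scalar
    tl.store(Out, m_once)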
I think we can use a better tile size search space, since in this PR the tile size is currently 1024, which is too small.
As to the modification to the algorithm, I will test it. Thanks.
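For illustration, a widened autotune search space could look like this (the concrete values and kernel signature are assumptions, not tuned numbers):

import triton
import triton.language as tl

# hypothetical search space with tiles larger than the current 1024
configs = [
    triton.Config({"TILE_N": tile_n}, num_warps=warps)
    for tile_n in [1024, 2048, 4096, 8192]
    for warps in [4, 8, 16]
]

@triton.autotune(configs=configs, key=["N"])
@triton.jit
def softmax_kernel(out_ptr, in_ptr, M, N, TILE_N: tl.constexpr):
    # kernel body elided; autotune picks TILE_N/num_warps per distinct N
    pass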
I tested the code. It achieves good performance in most cases for inner and outer reduction. Nice job.
But a contiguous tensor of shape (m, n, k), transposed into (m, k, n), cannot be viewed as (m*k, n) without copying the data. I replaced the view with reshape, but the performance is not good in some cases for middle reduction.
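A small repro of the view limitation (the shapes are arbitrary):

import torch

x = torch.randn(2, 3, 4)            # contiguous (m, n, k)
t = x.transpose(-1, 1)              # (m, k, n), no longer contiguous
try:
    t.view(-1, t.shape[-1])         # view cannot relayout strided memory
except RuntimeError as e:
    print(e)                        # "view size is not compatible ..."
flat = t.reshape(-1, t.shape[-1])   # works, but has to copy the data here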
…tioning (considering TILE_SIZE and number of blocks), use a better eviction policy, a better algorithm for the softmax online normalizer, reverse the second loop, specialize the last iteration, etc.
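A hedged sketch of the "reverse the second loop, specialize the last iteration" idea: the tail tile is the only one that needs a mask, and it is also the tile pass 1 touched last, so pass 2 handles it first while it may still be in cache, then walks the full tiles mask-free in reverse. The pass-1 results (m_max, acc) are passed in as scalars just to keep the sketch small; names and structure are illustrative, not the PR's code:

import triton
import triton.language as tl

@triton.jit
def _second_pass_reversed(Y, X, m_max, acc, N, BLOCK_SIZE: tl.constexpr):
    # specialized first step: the (possibly partial) tail tile, with a mask
    tail_start = (N - 1) // BLOCK_SIZE * BLOCK_SIZE
    cols = tail_start + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    vals = tl.load(X + cols, mask=mask, other=-float("inf")).to(tl.float32)
    tl.store(Y + cols, tl.exp(vals - m_max) / acc, mask=mask)
    # remaining tiles are full, so no masking; walk them in reverse so the
    # access order mirrors the end of pass 1
    num_full = tail_start // BLOCK_SIZE
    for i in range(0, num_full):
        start = tail_start - (i + 1) * BLOCK_SIZE
        cols = start + tl.arange(0, BLOCK_SIZE)
        vals = tl.load(X + cols, eviction_policy="evict_first").to(tl.float32)
        tl.store(Y + cols, tl.exp(vals - m_max) / acc)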
Performance: throughput (GB/s) of the forward function on RTX-3090, softmax on tensors of dtype float32, for the inner, outer, and middle reduction cases.
LGTM
Use different kernels (inner & non_inner) for softmax forward & backward.
1. inner: for reduction along the last dim (the input is preprocessed to be contiguous)
2. non_inner: for reduction along other dimensions (the input is preprocessed to be contiguous)

Both have a ONE_TILE_PER_CTA static condition:
1. when ONE_TILE_PER_CTA is True, load only one tile per cta, without looping over the reduction dim;
2. when ONE_TILE_PER_CTA is False, use the online softmax normalizer algorithm to save one sweep over the input.

The non_inner kernels now have better task partitioning to achieve better coalescing of global memory accesses. For a contiguous tensor with shape (M, N, K) where the reduction is applied along axis 1, if K > 1 it is considered a non-inner reduction. In this case, the shape of the data tile is (TILE_N, TILE_K). When TILE_K is large enough, access to global memory can be coalesced.
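As a concrete illustration of the coalescing claim (the sizes below are made up for the example):

import torch

M, N, K = 2, 3, 8
TILE_N, TILE_K = 2, 8
m = 1
n_offsets = torch.arange(TILE_N)
k_offsets = torch.arange(TILE_K)
# element (m, n, k) of a contiguous (M, N, K) tensor sits at flat offset
# m*N*K + n*K + k, so one tile row covers TILE_K consecutive addresses:
offsets = m * N * K + n_offsets[:, None] * K + k_offsets[None, :]
print(offsets)
# tensor([[24, 25, 26, 27, 28, 29, 30, 31],
#         [32, 33, 34, 35, 36, 37, 38, 39]])
# each row is one contiguous run, so a warp's loads coalesce when TILE_K
# spans it; a (TILE_N, 1) tile would touch addresses K elements apart.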