Skip to content

Commit

Permalink
some fixes and remove unused kernels
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Dec 12, 2024
1 parent 0f93e4a commit 742ae79
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 3,008 deletions.
4 changes: 2 additions & 2 deletions tests/models/decoder_only/language/test_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size

from ...utils import check_outputs_equal

Expand Down Expand Up @@ -205,7 +205,7 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
while len(example_prompts) == VllmConfig.get_graph_batch_size(len(example_prompts)):

Check failure on line 208 in tests/models/decoder_only/language/test_bamba.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

tests/models/decoder_only/language/test_bamba.py:208:81: E501 Line too long (88 > 80)
example_prompts.append(example_prompts[0])

try:
Expand Down
124 changes: 0 additions & 124 deletions vllm/model_executor/layers/mamba/ops/ssd_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,76 +90,6 @@ def _bmm_chunk_fwd_kernel(
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))


@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
],
key=['chunk_size', 'K'],
)
@triton.jit
def _bmm_chunk_bwd_kernel(
# Pointers to matrices
a_ptr, dout_ptr, db_ptr, res_ptr,
# Matrix dimensions
seqlen, chunk_size, K, ngroups,
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
# Meta-parameters
dot_dtype: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_ch = tl.program_id(axis=2)
pid_c = pid_ch // ngroups
pid_h = pid_ch - pid_c * ngroups
num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n

a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_cs = tl.arange(0, BLOCK_SIZE_CS)
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
acc += tl.dot(dout, a)
dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_RESIDUAL:
res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
acc += res
db = acc.to(db_ptr.dtype.element_ty)

db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))


def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
"""
Argument:
Expand Down Expand Up @@ -206,57 +136,3 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=No
HAS_SEQ_IDX=seq_idx is not None,
)
return out


def _bmm_chunk_bwd(a, dout, residual=None, out=None):
"""
Argument:
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
Return:
out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
zeroed out before calling this function.
"""
# Check constraints.
has_groups = a.dim() == 4
if not has_groups:
batch, seqlen, k = a.shape
else:
batch, seqlen, ngroups, k = a.shape
nchunks, chunk_size = dout.shape[1], dout.shape[-1]
if a.stride(-1) != 1 and a.stride(-2) != 1:
a = a.contiguous()
if dout.stride(-1) != 1 and dout.stride(-2) != 1:
dout = dout.contiguous()
if residual is not None:
assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
if residual.stride(-1) != 1 and residual.stride(1) != 1:
residual = residual.contiguous()
# Allocates output.
if out is not None:
assert out.shape == a.shape
assert out.stride(-1) == 1 or out.stride(1) == 1
else:
out = torch.empty_like(a)
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
(tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
nchunks if not has_groups else nchunks * ngroups)
residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
residual.stride(-1))
if residual is not None else (0, 0, 0, 0))
with torch.cuda.device(a.device.index):
_bmm_chunk_bwd_kernel[grid](
a, dout, out, residual,
seqlen, chunk_size, k, ngroups if has_groups else 1,
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
dot_dtype,
HAS_RESIDUAL=residual is not None,
)
return out
Loading

0 comments on commit 742ae79

Please sign in to comment.