
Commit

fmt + lint
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Dec 12, 2024
1 parent 742ae79 commit b2dc5ca
Showing 9 changed files with 1,313 additions and 429 deletions.
7 changes: 5 additions & 2 deletions tests/models/decoder_only/language/test_bamba.py
@@ -1,6 +1,6 @@
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.
This actually is really indentical to test_mamba, so maybe we can reuse
This actually is really identical to test_mamba, so maybe we can reuse
Run `pytest tests/models/decoder_only/language/test_bamba.py`.
"""
@@ -97,6 +97,7 @@ def test_batching(
name_1="batched_vllm",
)


@pytest.mark.skip("bamba does not support chunked prefill yet")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@@ -122,6 +123,7 @@ def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.skip("bamba does not support chunked prefill yet")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@@ -205,7 +207,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == VllmConfig.get_graph_batch_size(len(example_prompts)):
while len(example_prompts) == VllmConfig.get_graph_batch_size(
len(example_prompts)):
example_prompts.append(example_prompts[0])

try:
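The test_mamba_cache_cg_padding hunk above grows example_prompts until the batch length is no longer itself a CUDA-graph-captured size, which forces the mamba cache to be padded up to VllmConfig.get_graph_batch_size(len(example_prompts)). Below is a minimal stand-alone sketch of that round-up-and-pad idea; graph_batch_size and _CAPTURED_SIZES are hypothetical stand-ins for the real vLLM helper and its capture-size list, used here only for illustration.

# Minimal sketch (not vLLM code) of the padding idea in
# test_mamba_cache_cg_padding. graph_batch_size is a hypothetical stand-in
# for VllmConfig.get_graph_batch_size: it rounds a batch size up to the
# nearest CUDA-graph-captured size.

_CAPTURED_SIZES = [1, 2, 4, 8, 16, 32]  # assumed example capture sizes

def graph_batch_size(batch_size: int) -> int:
    for size in _CAPTURED_SIZES:
        if batch_size <= size:
            return size
    return _CAPTURED_SIZES[-1]

prompts = ["hello"] * 4
# Grow the batch until its length is *not* itself a captured size, so the
# mamba cache must be padded up to graph_batch_size(len(prompts)).
while len(prompts) == graph_batch_size(len(prompts)):
    prompts.append(prompts[0])

assert len(prompts) < graph_batch_size(len(prompts))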
233 changes: 127 additions & 106 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions vllm/model_executor/layers/mamba/ops/softplus.py
@@ -1,15 +1,21 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/softplus.py

# ruff: noqa: E501

import triton
import triton.language as tl
from packaging import version

TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")


if TRITON3:

@triton.jit
def softplus(dt):
return tl.math.log(tl.math.exp(dt) + 1)
else:

@triton.jit
def softplus(dt):
return tl.math.log1p(tl.exp(dt))
return tl.math.log1p(tl.exp(dt))
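Both branches of the version-gated softplus above compute log(1 + exp(dt)); the pre-Triton-3.0 branch simply spells it with log1p, which is the numerically nicer form for small exp(dt). A quick plain-PyTorch sanity check of that equivalence (not Triton code, purely illustrative):

# Plain-PyTorch check that the two softplus spellings agree numerically.
import torch

dt = torch.linspace(-10, 10, steps=101, dtype=torch.float32)
via_log = torch.log(torch.exp(dt) + 1)   # mirrors the TRITON3 branch
via_log1p = torch.log1p(torch.exp(dt))   # mirrors the pre-3.0 branch
assert torch.allclose(via_log, via_log1p, atol=1e-6)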
213 changes: 171 additions & 42 deletions vllm/model_executor/layers/mamba/ops/ssd_bmm.py
@@ -1,51 +1,134 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_bmm.py

# ruff: noqa: E501,SIM102
"""We want triton==2.1.0 or 2.2.0 for this
"""

import math
import torch
import torch.nn.functional as F

import torch
import triton
import triton.language as tl

from einops import rearrange, repeat


def init_to_zero(names):
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
return lambda nargs: [
nargs[name].zero_() for name in names if nargs[name] is not None
]


@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=2),
],
key=['chunk_size', 'K', 'IS_CAUSAL'],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices
a_ptr, b_ptr, out_ptr, seq_idx_ptr,
a_ptr,
b_ptr,
out_ptr,
seq_idx_ptr,
# Matrix dimensions
seqlen, chunk_size, K, ngroups,
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
stride_seq_idx_batch, stride_seq_idx_seqlen,
seqlen,
chunk_size,
K,
ngroups,
stride_a_batch,
stride_a_seqlen,
stride_a_head,
stride_ak,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_bk,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_outm,
stride_outn,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
dot_dtype: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_ch = tl.program_id(axis=2).to(tl.int64)
@@ -65,14 +148,22 @@ def _bmm_chunk_fwd_kernel(
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
offs_n[None, :] * stride_b_seqlen)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
a = tl.load(a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0).to(dot_dtype)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
(offs_n[None, :] < chunk_size_limit),
other=0.0).to(dot_dtype)
acc += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -81,16 +172,30 @@
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_SEQ_IDX:
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
mask=offs_n < chunk_size_limit,
other=-2)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
out = acc.to(out_ptr.dtype.element_ty)

out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
offs_n[None, :] * stride_outn)
tl.store(out_ptrs,
out,
mask=(offs_m[:, None] < chunk_size) &
(offs_n[None, :] < chunk_size))
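For intuition, the HAS_SEQ_IDX branch above zeroes every entry of the chunk-local score tile whose row token and column token belong to different packed sequences. A small plain-PyTorch sketch of that masking, with assumed toy shapes (not the kernel itself):

# Sketch (assumed shapes, not the kernel): zero out cross-sequence entries
# of a per-chunk score tile, mirroring the seq_idx_m == seq_idx_n check.
import torch

chunk = 4
acc = torch.randn(chunk, chunk)              # per-chunk A @ B^T tile
seq_idx = torch.tensor([0, 0, 1, 1])         # sequence id of each token
mask = seq_idx[:, None] == seq_idx[None, :]  # same-sequence pairs only
acc = torch.where(mask, acc, torch.zeros_like(acc))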


def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
def _bmm_chunk_fwd(a,
b,
chunk_size,
seq_idx=None,
causal=False,
output_dtype=None):
"""
Argument:
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
@@ -117,20 +222,44 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=No
nchunks = math.ceil(seqlen / chunk_size)
# Allocates output.
out_dtype = a.dtype if output_dtype is None else output_dtype
out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
device=a.device, dtype=out_dtype)
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
(tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
batch, nchunks if not has_groups else nchunks * ngroups)
out = torch.empty(
(batch, nchunks, chunk_size, chunk_size) if not has_groups else
(batch, nchunks, ngroups, chunk_size, chunk_size),
device=a.device,
dtype=out_dtype)
dot_dtype = (tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
(tl.float16 if a.dtype == torch.float16
or b.dtype == torch.float16 else tl.float32))
grid = lambda META: (triton.cdiv(
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
if not has_groups else nchunks * ngroups)
with torch.cuda.device(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a, b, out, seq_idx,
seqlen, chunk_size, k, ngroups if has_groups else 1,
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
a,
b,
out,
seq_idx,
seqlen,
chunk_size,
k,
ngroups if has_groups else 1,
a.stride(0),
a.stride(1),
0 if not has_groups else a.stride(2),
a.stride(-1),
b.stride(0),
b.stride(1),
0 if not has_groups else b.stride(2),
b.stride(-1),
out.stride(0),
out.stride(1),
0 if not has_groups else out.stride(2),
out.stride(-2),
out.stride(-1),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
causal,
dot_dtype,
HAS_SEQ_IDX=seq_idx is not None,
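Shape-wise, _bmm_chunk_fwd splits the sequence into chunk_size-long blocks and multiplies each block of a against the matching block of b transposed, producing one (chunk_size, chunk_size) tile per chunk. A rough PyTorch reference for the no-groups, no-seq_idx, non-causal case follows; it is a sketch for intuition only, not a drop-in replacement, and bmm_chunk_ref is a name invented here.

# Rough PyTorch reference for the simplest _bmm_chunk_fwd case.
import torch
import torch.nn.functional as F

def bmm_chunk_ref(a: torch.Tensor, b: torch.Tensor, chunk_size: int):
    # a, b: (batch, seqlen, k)
    batch, seqlen, k = a.shape
    nchunks = -(-seqlen // chunk_size)        # ceil division
    pad = nchunks * chunk_size - seqlen
    a = F.pad(a, (0, 0, 0, pad)).reshape(batch, nchunks, chunk_size, k)
    b = F.pad(b, (0, 0, 0, pad)).reshape(batch, nchunks, chunk_size, k)
    # out[bt, c, m, n] = sum_k a[bt, c, m, k] * b[bt, c, n, k]
    return torch.einsum("bcmk,bcnk->bcmn", a, b)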