From b2dc5cad298ffeb5c4d24209a1686532e7d75abc Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 12 Dec 2024 06:23:45 +0000 Subject: [PATCH] fmt + lint Signed-off-by: Yu Chin Fabian Lim --- .../decoder_only/language/test_bamba.py | 7 +- .../layers/mamba/mamba_mixer2.py | 233 +++---- .../layers/mamba/ops/softplus.py | 10 +- .../layers/mamba/ops/ssd_bmm.py | 213 +++++-- .../layers/mamba/ops/ssd_chunk_scan.py | 393 +++++++++--- .../layers/mamba/ops/ssd_chunk_state.py | 592 ++++++++++++++---- .../layers/mamba/ops/ssd_combined.py | 131 +++- .../layers/mamba/ops/ssd_state_passing.py | 102 ++- vllm/model_executor/models/bamba.py | 61 +- 9 files changed, 1313 insertions(+), 429 deletions(-) diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py index a3bcb644baf8b..d266135360563 100644 --- a/tests/models/decoder_only/language/test_bamba.py +++ b/tests/models/decoder_only/language/test_bamba.py @@ -1,6 +1,6 @@ """Compare the outputs of HF and vLLM when using greedy sampling for Mamba. -This actually is really indentical to test_mamba, so maybe we can reuse +This actually is really identical to test_mamba, so maybe we can reuse Run `pytest tests/models/decoder_only/language/test_bamba.py`. """ @@ -97,6 +97,7 @@ def test_batching( name_1="batched_vllm", ) + @pytest.mark.skip("bamba does not support chunked prefill yet") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @@ -122,6 +123,7 @@ def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + @pytest.mark.skip("bamba does not support chunked prefill yet") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @@ -205,7 +207,8 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. 
If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == VllmConfig.get_graph_batch_size(len(example_prompts)): + while len(example_prompts) == VllmConfig.get_graph_batch_size( + len(example_prompts)): example_prompts.append(example_prompts[0]) try: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index b2a4b2aaefc78..150ee86b4ca3b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,36 +1,35 @@ +from typing import List, Optional, Tuple, Union + import torch from torch import nn -from torch.nn.parameter import Parameter - -# Added by the IBM Team, 2024 from vllm.attention.backends.abstract import AttentionMetadata +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) - -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_state_update) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + LoaderFunction, composed_weight_loader, sharded_weight_loader) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs -from vllm.distributed import (divide, get_tensor_model_parallel_world_size, - get_tensor_model_parallel_rank, - tensor_model_parallel_all_reduce) -from vllm.model_executor.model_loader.weight_utils import ( - composed_weight_loader, sharded_weight_loader, LoaderFunction) -from typing import Tuple, Union, Optional, List -from vllm.model_executor.custom_op import CustomOp +# Added by the IBM Team, 2024 + # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated # also referenced https://github.com/vllm-project/vllm/pull/9292 @CustomOp.register("mixer2_gated_rms_norm") class Mixer2RMSNormGated(CustomOp): + def __init__(self, hidden_size, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -84,6 +83,7 @@ def forward_cuda( ) return out + def extra_groups_for_head_shards(ngroups: int, tp_size: int): """Compute the extra (logical) groups to account for head shards""" @@ -93,12 +93,16 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int): return tp_size - ngroups % tp_size + def mamba_v2_sharded_weight_loader( - shard_spec: List[int], tp_size: int, tp_rank: int, + shard_spec: List[Tuple[int, int, float]], + tp_size: int, + tp_rank: int, ) -> LoaderFunction: - """Create a weight loader for mamba v2. This ensures that the projections are - correctly sharded so that they can be split into x, B, C. It also ensures the - the all the groups corresponding to a head shard is placed together with it. + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures the the all the groups corresponding to a head shard is placed + together with it. 
""" def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: @@ -116,18 +120,21 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: rank = tp_rank // ratio # - should start from here (determined by rank) - loaded_skip = rank * shard_size # take these number dims from loaded + # - take these number dims from loaded + loaded_skip = rank * shard_size loaded_start_idx = loaded_boundary + loaded_skip # - these many number dims to take from loaded_weight take = min(shard_size, full_dim - extra - loaded_skip) # - always shard on dim 0 - param.data[ - boundary:boundary+take,... - ] = loaded_weight[ - loaded_start_idx:loaded_start_idx+take - ] + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[boundary:(boundary + take), # type: ignore[misc] + ...] = loaded_weight[ + loaded_start_idx:( # type: ignore[misc] + loaded_start_idx + take)] # type: ignore[misc] # move boundaries boundary += shard_size @@ -135,8 +142,9 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: return loader + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer -@CustomOp.register("mamba_mixer2") +@CustomOp.register("mamba_mixer2") class MambaMixer2(CustomOp): """ Compute ∆, A, B, C, and D the state space parameters and compute @@ -165,17 +173,17 @@ def __init__(self, super().__init__() # For TP, the sharding plan is as follows: - # - for the conv modules, since + # - for the conv modules, since # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, # we shard intermediate_size and n_groups # - since intermediate_size = n_heads * head_dim, sharding on # intermediate_size is achieved by sharding on n_heads. - # - so if world_size divides groups, then sharding + # - so if world_size divides groups, then sharding # (n_groups / world_size, n_heads / world_size) # also maintains the invariant n_heads % n_groups == 0 - # - HOWEVER< if world_size DOES NOT divide groups, then we need to allocate - # extra space in the shard, such that the WHOLE GROUP must be placed - # together with the HEAD SHARD. + # - HOWEVER< if world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that the WHOLE GROUP + # must be placed together with the HEAD SHARD. 
self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() @@ -190,14 +198,14 @@ def __init__(self, self.n_groups = n_groups if n_groups % self.tp_size != 0: - # - for TP we shard conv_dim by sharding on n_groups, - # - but if n_groups cannot divide tp_size, we need to + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to # extend some extra groups - self.n_groups = n_groups + extra_groups_for_head_shards(n_groups, self.tp_size) + self.n_groups = n_groups + extra_groups_for_head_shards( + n_groups, self.tp_size) - self.conv_dim = ( - intermediate_size + 2 * self.n_groups * ssm_state_size - ) + self.conv_dim = (intermediate_size + + 2 * self.n_groups * ssm_state_size) self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, output_size=self.conv_dim, @@ -210,62 +218,76 @@ def __init__(self, # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = ColumnParallelLinear( - input_size=hidden_size, - output_size=intermediate_size + self.conv_dim + self.num_heads, - bias=use_bias, - quant_config=quant_config) + self.in_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size + + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config) - # - because in_proj is a concatenation of 3 weights, we + # - because in_proj is a concatenation of 3 weights, we # need to interleave them before sharding # - use the custom weight loader mamba_v2_sharded_weight_loader # for conv1d.bias, covn1d.weight and in_proj.weight # - need to set these settings, to assign the groups to the head shards group_shard_settings = ( - self.n_groups * self.ssm_state_size, # expected model size - (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned - self.num_heads // n_groups, # ratio for mapping back to original group + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + self.num_heads // + n_groups, # ratio for mapping back to original group ) intemediate_settings = (intermediate_size, 0, 1) head_setings = (self.num_heads, 0, 1) delattr(self.conv1d.bias, "weight_loader") - set_weight_attrs(self.conv1d.bias, { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - intemediate_settings, group_shard_settings, group_shard_settings, - ], - self.tp_size, tp_rank, - ) - }) + set_weight_attrs( + self.conv1d.bias, { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intemediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }) delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs(self.conv1d.weight, { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - intemediate_settings, group_shard_settings, group_shard_settings, - ], - self.tp_size, tp_rank - ) - }) + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + intemediate_settings, + group_shard_settings, + group_shard_settings, + ], self.tp_size, tp_rank) + }) delattr(self.in_proj.weight, "weight_loader") - set_weight_attrs(self.in_proj.weight, { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - intemediate_settings, # for gate - intemediate_settings, group_shard_settings, group_shard_settings, - head_setings, # for dt - ], - self.tp_size, tp_rank - ) - }) - - # - these are TPed by heads to reduce the size of the + set_weight_attrs( + self.in_proj.weight, + { + 
"weight_loader": + mamba_v2_sharded_weight_loader( + [ + intemediate_settings, # for gate + intemediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) + + # - these are TPed by heads to reduce the size of the # temporal shape self.A = nn.Parameter( torch.empty( - divide(num_heads, self.tp_size), dtype=torch.float32, + divide(num_heads, self.tp_size), + dtype=torch.float32, )) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) @@ -277,16 +299,14 @@ def __init__(self, set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - self.out_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=use_bias, - input_is_parallel=True, - quant_config=quant_config) + self.out_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config) - self.norm = Mixer2RMSNormGated( - intermediate_size // self.tp_size, eps=rms_norm_eps - ) + self.norm = Mixer2RMSNormGated(intermediate_size // self.tp_size, + eps=rms_norm_eps) def forward_native(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, @@ -297,27 +317,27 @@ def forward_cuda(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams): - seq_len, _ = hidden_states.shape groups_time_state_size = self.n_groups * self.ssm_state_size # - doing it differently from mixer v1; little confused with its logic - # - we need to do is to detect if there is any prefill; if there are + # - we need to do is to detect if there is any prefill; if there are # no prefils, then each example will be coming in one sample at a time - # - on the other hand v1 checks for "query_start_loc" and "context_lens_tensor" - # however we have noticed that, even when the samples are coming in - # one at a time, they are still non-NO.e + # - on the other hand v1 checks for "query_start_loc" + # and "context_lens_tensor" however we have noticed that, even + # when the samples are coming in + # one at a time, they are still not NONE, e.g., # * "query_start_loc" = [0, 1, ..] # * "context_lens_tensor" = [8, ...] - has_prefill = attn_metadata.num_prefills > 0 + has_prefill = attn_metadata.num_prefills > 0 # 1. 
Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) gate, hidden_states_B_C, dt = torch.split( projected_states, [ - self.intermediate_size // self.tp_size, - self.conv_dim // self.tp_size, + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, self.num_heads // self.tp_size, ], dim=-1, @@ -335,7 +355,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, # |-------------------- seq_len ---------------------| # |-- query_len ---| - # - "cache_indices" upates the conv_state cache in positions + # - "cache_indices" updates the conv_state cache in positions # pointed to by "mamba_cache_params.state_indices_tensor" hidden_states_B_C = causal_conv1d_fn( hidden_states_B_C.transpose(0, 1), @@ -345,8 +365,8 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc - ).transpose(0, 1)[:seq_len] + query_start_loc=attn_metadata.query_start_loc).transpose( + 0, 1)[:seq_len] else: hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, @@ -354,14 +374,13 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor - ) + conv_state_indices=mamba_cache_params.state_indices_tensor) # - get hidden_states, B and C after depthwise convolution. hidden_states, B, C = torch.split( hidden_states_B_C, [ - self.intermediate_size // self.tp_size, + self.intermediate_size // self.tp_size, groups_time_state_size // self.tp_size, groups_time_state_size // self.tp_size, ], @@ -370,12 +389,12 @@ def forward_cuda(self, hidden_states: torch.Tensor, # 3. State Space Model sequence transformation if has_prefill: - + # FIXME: we are having problems using mamba_chunk_scan_combined # with chunked prefill. This is because there is no # initial_states requires initial_states.shape[0] to match # the batch size, but cu_seqlens requires batch_size = 1. - # Therefore as of now, initial_states and cu_seqlens are + # Therefore as of now, initial_states and cu_seqlens are # mutually exclusive. initial_states = None @@ -385,7 +404,8 @@ def forward_cuda(self, hidden_states: torch.Tensor, # ] scan_output, varlen_state = mamba_chunk_scan_combined( - hidden_states.view(1, seq_len, self.num_heads // self.tp_size, self.head_dim), + hidden_states.view(1, seq_len, self.num_heads // self.tp_size, + self.head_dim), dt.unsqueeze(0), self.A, B.view(1, seq_len, self.n_groups // self.tp_size, -1), @@ -412,15 +432,17 @@ def forward_cuda(self, hidden_states: torch.Tensor, hidden_states = scan_output.view(seq_len, -1) else: - # NOTE: can be optimized? + # NOTE: can be optimized? 
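An aside that is not part of the patch: a minimal sketch, with assumed sizes, of how the prefill branch above packs a variable-length batch for mamba_chunk_scan_combined (a single batch entry, with sequence boundaries carried in cu_seqlens, as the FIXME comment above notes).

import torch

num_heads, head_dim = 4, 8                 # assumed per-rank sizes
seq_lens = [5, 3, 7]                       # three prefill requests, assumed
seq_len = sum(seq_lens)                    # 15 tokens packed back to back
cu_seqlens = torch.tensor([0, 5, 8, 15])   # query_start_loc-style offsets

hidden_states = torch.randn(seq_len, num_heads * head_dim)
# the code above views the packed tokens as one batch entry of length seq_len
x = hidden_states.view(1, seq_len, num_heads, head_dim)
# cu_seqlens is passed alongside so the kernel also emits one final SSM state
# per request (varlen_state), which the code above scatters into
# mamba_cache_params.ssm_state at state_indices_tensor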
n_groups = self.n_groups // self.tp_size - A = self.A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) dt = dt[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) D = self.D[:, None, ...].expand(-1, self.head_dim) B = B.view(-1, n_groups, B.shape[1] // n_groups) C = C.view(-1, n_groups, C.shape[1] // n_groups) - hidden_states_reshaped = hidden_states.view(-1, self.num_heads // self.tp_size, self.head_dim) + hidden_states_reshaped = hidden_states.view( + -1, self.num_heads // self.tp_size, self.head_dim) # - the hidden is reshaped into number of current batches # - in this case there is no more prefil, so the batches gen @@ -434,22 +456,21 @@ def forward_cuda(self, hidden_states: torch.Tensor, mamba_cache_params.ssm_state, hidden_states_reshaped, dt, - A, + A, B, C, - D, + D, z=None, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=mamba_cache_params.state_indices_tensor, ) hidden_states = hidden_states.view( - -1, (self.num_heads // self.tp_size) * self.head_dim - ) + -1, (self.num_heads // self.tp_size) * self.head_dim) # # 4. gated MLP hidden_states = self.norm(hidden_states, gate) # # 5. Final linear projection out, _ = self.out_proj(hidden_states) - return out \ No newline at end of file + return out diff --git a/vllm/model_executor/layers/mamba/ops/softplus.py b/vllm/model_executor/layers/mamba/ops/softplus.py index 5541655c66160..5ec75be51bf3b 100644 --- a/vllm/model_executor/layers/mamba/ops/softplus.py +++ b/vllm/model_executor/layers/mamba/ops/softplus.py @@ -1,15 +1,21 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/softplus.py + +# ruff: noqa: E501 + import triton import triton.language as tl from packaging import version TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") - if TRITON3: + @triton.jit def softplus(dt): return tl.math.log(tl.math.exp(dt) + 1) else: + @triton.jit def softplus(dt): - return tl.math.log1p(tl.exp(dt)) \ No newline at end of file + return tl.math.log1p(tl.exp(dt)) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 312a65769b634..3eba3c49b4590 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -1,51 +1,134 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_bmm.py +# ruff: noqa: E501,SIM102 """We want triton==2.1.0 or 2.2.0 for this """ import math -import torch -import torch.nn.functional as F +import torch import triton import triton.language as tl -from einops import rearrange, repeat - def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + return lambda nargs: [ + nargs[name].zero_() for name in names if nargs[name] is not None + ] @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), ], key=['chunk_size', 'K', 'IS_CAUSAL'], ) @triton.jit def _bmm_chunk_fwd_kernel( # Pointers to matrices - a_ptr, b_ptr, out_ptr, seq_idx_ptr, + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, # Matrix dimensions - seqlen, chunk_size, K, ngroups, - stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, - stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk, - stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn, - stride_seq_idx_batch, stride_seq_idx_seqlen, + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, # Meta-parameters IS_CAUSAL: tl.constexpr, dot_dtype: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, - 
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_ch = tl.program_id(axis=2).to(tl.int64) @@ -65,14 +148,22 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + + offs_n[None, :] * stride_b_seqlen) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype) - b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype) + a = tl.load(a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & + (offs_n[None, :] < chunk_size_limit), + other=0.0).to(dot_dtype) acc += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -81,16 +172,30 @@ def _bmm_chunk_fwd_kernel( offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_SEQ_IDX: chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) out = acc.to(out_ptr.dtype.element_ty) out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) - tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & + (offs_n[None, :] < chunk_size)) + -def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): +def _bmm_chunk_fwd(a, + b, + chunk_size, + seq_idx=None, + causal=False, + output_dtype=None): """ Argument: a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) @@ -117,20 +222,44 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=No nchunks = math.ceil(seqlen / chunk_size) # Allocates output. 
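Not part of the patch: a plain-PyTorch reference of what _bmm_chunk_fwd computes in the simplest case (no groups, no causal mask, no seq_idx), intended only as a reading aid. Shapes follow the docstring above; the helper name is hypothetical.

import math
import torch
import torch.nn.functional as F

def bmm_chunk_ref(a: torch.Tensor, b: torch.Tensor, chunk_size: int) -> torch.Tensor:
    batch, seqlen, k = a.shape
    nchunks = math.ceil(seqlen / chunk_size)
    pad = nchunks * chunk_size - seqlen
    a = F.pad(a, (0, 0, 0, pad)).view(batch, nchunks, chunk_size, k)
    b = F.pad(b, (0, 0, 0, pad)).view(batch, nchunks, chunk_size, k)
    # out[batch, c, i, j] = <a[batch, c, i, :], b[batch, c, j, :]>: one
    # (chunk_size x chunk_size) score matrix per chunk, matching the kernel's
    # (batch, nchunks, chunk_size, chunk_size) output
    return torch.einsum("bcik,bcjk->bcij", a, b)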
out_dtype = a.dtype if output_dtype is None else output_dtype - out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size), - device=a.device, dtype=out_dtype) - dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else - (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32)) - grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), - batch, nchunks if not has_groups else nchunks * ngroups) + out = torch.empty( + (batch, nchunks, chunk_size, chunk_size) if not has_groups else + (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) + dot_dtype = (tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 + or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + chunk_size, META['BLOCK_SIZE_N']), batch, nchunks + if not has_groups else nchunks * ngroups) with torch.cuda.device(a.device.index): _bmm_chunk_fwd_kernel[grid]( - a, b, out, seq_idx, - seqlen, chunk_size, k, ngroups if has_groups else 1, - a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), - b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1), - out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), causal, dot_dtype, HAS_SEQ_IDX=seq_idx is not None, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 79fa52e0b8c4f..c538aaa464171 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -1,55 +1,175 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_chunk_scan.py +# ruff: noqa: E501 """We want triton==2.1.0 or 2.2.0 for this """ -from packaging import version - import torch - import triton import triton.language as tl +from packaging import version TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + return lambda nargs: [ + nargs[name].zero_() for name in names if nargs[name] is not None + ] @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), ], key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) @triton.jit def _chunk_scan_fwd_kernel( # Pointers to matrices - cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + prev_states_ptr, + D_ptr, # Matrix dimensions - chunk_size, hdim, dstate, - 
batch, seqlen, nheads_ngroups_ratio, + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, # Strides - stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, - stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, - stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, stride_D_head, # Meta-parameters IS_CAUSAL: tl.constexpr, @@ -57,7 +177,9 @@ def _chunk_scan_fwd_kernel( D_HAS_HDIM: tl.constexpr, HAS_Z: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, ): @@ -68,23 +190,31 @@ def _chunk_scan_fwd_kernel( num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + ( + pid_h // nheads_ngroups_ratio) * stride_cb_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_C_head prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, + mask=offs_m < chunk_size, + other=0.0).to(tl.float32) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - 
stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=pid_c >= 1, + other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Without the if (pid_c > -1), with Triton 2.1.0, I get @@ -92,23 +222,40 @@ def _chunk_scan_fwd_kernel( # With Triton 2.2.0, this works if IS_TRITON_22 or pid_c > -1: # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * stride_states_hdim + + offs_k_dstate[:, None] * stride_states_dstate) if not HAS_SEQ_IDX: scale_m = tl.exp(dA_cs_m) else: scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) if BLOCK_SIZE_DSTATE <= 128: - C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K @@ -116,24 +263,36 @@ def _chunk_scan_fwd_kernel( acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) - x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim) dt_ptrs = dt_ptr + offs_k * stride_dt_csize dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + K_MAX = chunk_size_limit if not IS_CAUSAL else min( + (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) for k 
in range(0, K_MAX, BLOCK_SIZE_K): - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb = tl.load(cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. - cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: mask = offs_m[:, None] >= k + offs_k[None, :] cb = tl.where(mask, cb, 0.0) cb = cb.to(x_ptr.dtype.element_ty) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < hdim), + other=0.0) acc += tl.dot(cb, x) cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k x_ptrs += BLOCK_SIZE_K * stride_x_seqlen @@ -145,28 +304,54 @@ def _chunk_scan_fwd_kernel( if HAS_D: if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, + mask=offs_n < hdim, + other=0.0).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), - mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_n[None, :] < hdim), + other=0.0).to(tl.float32) acc += x_residual * D if HAS_Z: out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) - tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :]) + tl.store(out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head - z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) - z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim), + other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) - tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + out_ptrs = out_ptr + (stride_out_seqlen * 
offs_out_m[:, None] + + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + -def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): +def _chunk_scan_fwd(cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = C.shape @@ -176,36 +361,88 @@ def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=Non if z is not None: assert z.shape == x.shape if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) # Allocates output. - out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + out = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) if z is not None: - out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + out_x = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) assert out_x.stride() == out.stride() else: out_x = None - grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) - z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) - if z is not None else (0, 0, 0, 0)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) _chunk_scan_fwd_kernel[grid]( - cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D, - chunk_size, headdim, dstate, - batch, seqlen, nheads // ngroups, - cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - z_strides[0], z_strides[1], z_strides[2], z_strides[3], - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - C.stride(0), C.stride(1), C.stride(2), C.stride(3), - states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else + (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + 
states.stride(4), D.stride(0) if D is not None else 0, True, D is not None, @@ -215,4 +452,4 @@ def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=Non HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, ) - return out, out_x \ No newline at end of file + return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 3184bbbf03d41..bafdcd2585e5a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -1,22 +1,24 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_chunk_state.py +# ruff: noqa: E501 """We want triton==2.1.0 or 2.2.0 for this """ import math -import torch -import torch.nn.functional as F +import torch import triton import triton.language as tl -from einops import rearrange, repeat - from .softplus import softplus def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + return lambda nargs: [ + nargs[name].zero_() for name in names if nargs[name] is not None + ] + @triton.autotune( configs=[ @@ -33,20 +35,37 @@ def init_to_zero(names): @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices - dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, # Matrix dimension - batch, seqlen, nheads, chunk_size, - dt_min, dt_max, + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, # Strides - stride_dt_batch, stride_dt_seqlen, stride_dt_head, + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, stride_A_head, stride_dt_bias_head, - stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, # Meta-parameters DT_SOFTPLUS: tl.constexpr, HAS_DT_BIAS: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, ): pid_b = tl.program_id(axis=0) @@ -60,60 +79,165 @@ def _chunk_cumsum_fwd_kernel( offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) - dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + + offs_c[None, :] * stride_dt_seqlen) A_ptrs = A_ptr + offs_h * stride_A_head - dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize) - dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + + offs_c[None, :] * stride_dA_cs_csize) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + dt = tl.load(dt_ptrs, + mask=(offs_h[:, None] < nheads) & + (offs_c[None, :] < chunk_size_limit), + other=0.0).to(tl.float32) if HAS_DT_BIAS: - dt_bias = tl.load(dt_bias_ptr + offs_h * 
stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, + mask=offs_h < nheads, + other=0.0).to(tl.float32) dt += dt_bias[:, None] if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, softplus(dt), dt) # As of Triton 2.2.0, tl.clamp is not available yet # dt = tl.clamp(dt, dt_min, dt_max) dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) - dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) - tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + 0.0) + tl.store(dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) - tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + tl.store(dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), ], key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit def _chunk_state_fwd_kernel( # Pointers to matrices - x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, # Matrix dimensions - hdim, dstate, chunk_size, - batch, seqlen, nheads_ngroups_ratio, + hdim, + dstate, 
+ chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, - stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, # Meta-parameters HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch @@ -122,7 +246,8 @@ def _chunk_state_fwd_kernel( num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head @@ -132,30 +257,46 @@ def _chunk_state_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize if HAS_SEQ_IDX: seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) if HAS_SEQ_IDX: - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, 
mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) if HAS_SEQ_IDX: - seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) if not HAS_SEQ_IDX: - scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k else: - scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -170,40 +311,130 @@ def _chunk_state_fwd_kernel( states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) + @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 
'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), ], key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit def _chunk_state_varlen_kernel( # Pointers to matrices - x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr, + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, # Matrix dimensions - hdim, dstate, chunk_size, - seqlen, nheads_ngroups_ratio, + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, # Strides - stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_b_seqlen, stride_b_head, stride_b_dstate, - stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate, - stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) @@ -212,7 +443,8 @@ def _chunk_state_varlen_kernel( pid_n = tl.program_id(axis=0) % num_pid_n end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) pid_c = (end_idx - 1) // chunk_size - b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + b_ptr += pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head @@ -221,10 +453,13 @@ def _chunk_state_varlen_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * + stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize chunk_size_limit = end_idx - pid_c * chunk_size @@ -233,12 +468,24 @@ def _chunk_state_varlen_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = 
tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k) & + (offs_k[None, :] >= start_idx_cur - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate) & + (offs_k[:, None] >= start_idx_cur - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -249,8 +496,13 @@ def _chunk_state_varlen_kernel( # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk if start_idx < pid_c * chunk_size: - chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate) - chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + chunk_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + chunk_states = tl.load(chunk_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0) scale = tl.exp(dA_cs_last) acc += chunk_states * scale @@ -260,37 +512,77 @@ def _chunk_state_varlen_kernel( states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) -def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): +def _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): batch, seqlen, nheads = dt.shape - assert A.shape == (nheads,) + assert A.shape == (nheads, ) if dt_bias is not None: - assert dt_bias.shape == (nheads,) + assert dt_bias.shape == (nheads, ) nchunks = math.ceil(seqlen / chunk_size) - dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + dt_out = 
torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + dA_cumsum = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, + triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( - dt, A, dt_bias, dt_out, dA_cumsum, - batch, seqlen, nheads, chunk_size, - dt_limit[0], dt_limit[1], - dt.stride(0), dt.stride(1), dt.stride(2), + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), A.stride(0), dt_bias.stride(0) if dt_bias is not None else 0, - dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out -def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True): + +def _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=None, + states=None, + states_in_fp32=True): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = B.shape @@ -304,24 +596,54 @@ def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_f assert states.shape == (batch, nchunks, nheads, headdim, dstate) else: states_dtype = torch.float32 if states_in_fp32 else B.dtype - states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) - grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) + states = torch.empty((batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype) + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( + dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_state_fwd_kernel[grid]( - x, B, states, dt, dA_cumsum, seq_idx, - headdim, dstate, chunk_size, - batch, seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - B.stride(0), B.stride(1), B.stride(2), B.stride(-1), - states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_SEQ_IDX=seq_idx is not None, ) return states + def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, 
chunk_states): total_seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape @@ -333,19 +655,47 @@ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape assert chunk_states.shape == (nchunks, nheads, headdim, dstate) - states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device) - grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), - batch, nheads) + states = torch.empty(batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. + cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) with torch.cuda.device(x.device.index): _chunk_state_varlen_kernel[grid]( - x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states, - headdim, dstate, chunk_size, - total_seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), - B.stride(0), B.stride(1), B.stride(2), - dt.stride(1), dt.stride(0), dt.stride(2), - dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2), - chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3), - states.stride(0), states.stride(1), states.stride(2), states.stride(3), + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), ) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 728024a6b31fa..90854fd0c0a10 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -1,50 +1,67 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. 
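For orientation, here is a rough, self-contained sketch of how the chunk_state_varlen wrapper above could be invoked. The sizes below are illustrative assumptions chosen only to satisfy the shape assertions in the wrapper, and it needs a CUDA device since the underlying Triton kernel runs on GPU:

    import torch
    from vllm.model_executor.layers.mamba.ops.ssd_chunk_state import (
        chunk_state_varlen)

    # Hypothetical sizes; only the shape relationships matter here.
    nheads, ngroups, headdim, dstate, chunk_size = 8, 1, 64, 16, 256
    cu_seqlens = torch.tensor([0, 100, 300], device="cuda", dtype=torch.int32)
    total_seqlen = int(cu_seqlens[-1])
    nchunks = (total_seqlen + chunk_size - 1) // chunk_size

    x = torch.randn(total_seqlen, nheads, headdim, device="cuda")
    B = torch.randn(total_seqlen, ngroups, dstate, device="cuda")
    dt = torch.rand(nheads, nchunks, chunk_size, device="cuda")
    dA_cumsum = torch.cumsum(-dt, dim=-1)  # any float values work for a dry run
    chunk_states = torch.randn(nchunks, nheads, headdim, dstate, device="cuda")

    # One final SSM state per variable-length sequence:
    # (batch, nheads, headdim, dstate).
    states = chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states)
    assert states.shape == (cu_seqlens.numel() - 1, nheads, headdim, dstate)
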
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_combined.py +# ruff: noqa: E501 """We want triton==2.1.0 or 2.2.0 for this """ -from packaging import version - import torch - import triton - from einops import rearrange +from packaging import version from .ssd_bmm import _bmm_chunk_fwd -from .ssd_chunk_state import _chunk_cumsum_fwd -from .ssd_chunk_state import _chunk_state_fwd -from .ssd_chunk_state import chunk_state_varlen -from .ssd_state_passing import _state_passing_fwd from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, + chunk_state_varlen) +from .ssd_state_passing import _state_passing_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + return lambda nargs: [ + nargs[name].zero_() for name in names if nargs[name] is not None + ] -def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + +def _mamba_chunk_scan_combined_fwd(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (batch, seqlen, ngroups, dstate) assert x.shape == (batch, seqlen, nheads, headdim) assert dt.shape == (batch, seqlen, nheads) - assert A.shape == (nheads,) + assert A.shape == (nheads, ) assert C.shape == B.shape if z is not None: assert z.shape == x.shape if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() - if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + if x.stride(-1) != 1 and x.stride( + 1) != 1: # Either M or K dimension should be contiguous x = x.contiguous() - if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + if z is not None and z.stride(-1) != 1 and z.stride( + 1) != 1: # Either M or K dimension should be contiguous z = z.contiguous() if D is not None and D.stride(-1) != 1: D = D.contiguous() @@ -54,28 +71,73 @@ def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, d # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) - dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) - states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + dA_cumsum, dt = _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + states = _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=seq_idx, + states_in_fp32=True) # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, 
states_in_fp32=True) # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True) # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True) - states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], - initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, - seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype) - states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]] + states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None else None, + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=C.dtype) + states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states]) # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) - CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) - out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx) + CB = _bmm_chunk_fwd(C, + B, + chunk_size, + seq_idx=seq_idx, + output_dtype=torch.float32) + out, out_x = _chunk_scan_fwd(CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx) if cu_seqlens is None: return out, out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0), + varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), + dt.squeeze(0), dA_cumsum.squeeze(0), cu_seqlens, states.squeeze(0)) return out, out_x, dt, dA_cumsum, states, final_states, varlen_states -def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): + +def mamba_chunk_scan_combined(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_final_states=False, + return_varlen_states=False): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -99,9 +161,26 @@ def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bia cu_seqlens = None else: assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" - out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + 
dt_softplus=dt_softplus, + dt_limit=dt_limit) if not return_varlen_states: return out if not return_final_states else (out, final_states) else: varlen_states = rest[0] - return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) \ No newline at end of file + return (out, + varlen_states) if not return_final_states else (out, + final_states, + varlen_states) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 59ed1d17cfda2..dfc87fc7e5c68 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -1,10 +1,11 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_state_passing.py +# ruff: noqa: E501 """We want triton==2.1.0 or 2.2.0 for this """ import torch - import triton import triton.language as tl @@ -23,16 +24,37 @@ @triton.jit def _state_passing_fwd_kernel( # Pointers to matrices - states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, # Matrix dimensions - dim, nchunks, seqlen, chunk_size, + dim, + nchunks, + seqlen, + chunk_size, # Strides - stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim, - stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, - stride_final_states_batch, stride_final_states_head, stride_final_states_dim, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, - stride_initstates_batch, stride_initstates_head, stride_initstates_dim, - stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, # Meta-parameters HAS_INITSTATES: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, @@ -59,16 +81,20 @@ def _state_passing_fwd_kernel( states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) else: initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim - states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + states = tl.load(initstates_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk seq_idx = 0 for c in range(nchunks): - new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + new_states = tl.load(states_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale = tl.exp(dA_cs) if HAS_SEQ_IDX: - seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) + seq_idx_new = tl.load(seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * + stride_seq_idx_seqlen) scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) seq_idx = seq_idx_new states = scale * states + new_states @@ -81,7 +107,11 @@ def _state_passing_fwd_kernel( out_ptrs += stride_out_chunk -def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, +def _state_passing_fwd(states, + dA_chunk_cumsum, + 
initial_states=None, + seq_idx=None, + chunk_size=None, out_dtype=None): batch, nchunks, nheads, dim = states.shape assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) @@ -92,20 +122,44 @@ def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=Non seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype - out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) - final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32) + out = torch.empty((batch, nchunks, nheads, dim), + device=states.device, + dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), + device=states.device, + dtype=torch.float32) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( - states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx, - dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, - states.stride(0), states.stride(1), states.stride(2), states.stride(3), - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - final_states.stride(0), final_states.stride(1), final_states.stride(2), - dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), - *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) - if initial_states is not None else (0, 0, 0)), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + states, + out, + final_states, + dA_chunk_cumsum, + initial_states, + seq_idx, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size if seq_idx is not None else 0, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_chunk_cumsum.stride(0), + dA_chunk_cumsum.stride(2), + dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2)) if initial_states is not None else + (0, 0, 0)), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_INITSTATES=initial_states is not None, HAS_SEQ_IDX=seq_idx is not None, ) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 5c6a8ab043170..2693c45b27520 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -10,16 +10,16 @@ from vllm.attention.layer import Attention from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.quantization import QuantizationConfig +from 
vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -67,6 +67,7 @@ def forward(self, x):
         x, _ = self.down_proj(x)
         return x
 
+
 class BambaMixerDecoderLayer(nn.Module):
 
     def __init__(self,
@@ -161,7 +162,7 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             base=rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),  # see impl of get_rope
+            dtype=torch.get_default_dtype(),  # see impl of get_rope
         )
 
         self.qkv_proj = QKVParallelLinear(
@@ -203,23 +204,28 @@ def self_attention(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
-        # because the bamba model may potentially handle long sequences,
-        # we should adjust the sin_cos cache if necesary to avoid out of bounds
+        # because the bamba model may potentially handle long sequences,
+        # we should adjust the sin_cos cache if necessary to avoid out of bounds
         # - first get the max_position
         max_position = max(
            getattr(attn_metadata, 'max_prefill_seq_len', 0),
            getattr(attn_metadata, 'max_decode_seq_len', 0),
        )
         if max_position == 0:
-            # if we cannot get the max lenght from the metadata, then
-            # get it frmo the positions
+            # if we cannot get the max length from the metadata, then
+            # get it from the positions
             max_position = positions.max().item()
 
-        if self.rotary_emb.max_position_embeddings <= max_position:
+        # setting VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 could potentially cause
+        # inputs longer than max_position_embeddings. We extend the rope cache
+        # to prevent CUDA errors. Be aware that the outputs could be of
+        # lower quality for long sequence lengths.
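The cache-extension code that follows simply rounds the rope cache length up to the next power-of-two multiple of the configured max_position_embeddings. A minimal standalone sketch of that arithmetic (the helper name here is illustrative, not part of the model code):

    def extended_rope_cache_len(max_position_embeddings: int,
                                max_position: int) -> int:
        # Keep doubling until the cache covers the largest position index
        # seen, mirroring the while-loop over rotary.max_position_embeddings
        # below.
        new_len = max_position_embeddings
        while new_len <= max_position:
            new_len *= 2
        return new_len

    # e.g. a 4096-entry cache with a max position index of 9000 grows to 16384.
    assert extended_rope_cache_len(4096, 9000) == 16384
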
+ rotary = self.rotary_emb + if rotary.max_position_embeddings <= max_position: # we set it to the next power of two that covers it - while self.rotary_emb.max_position_embeddings <= max_position: - self.rotary_emb.max_position_embeddings *= 2 - self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache() + while rotary.max_position_embeddings <= max_position: + rotary.max_position_embeddings *= 2 + rotary.cos_sin_cache = rotary._compute_cos_sin_cache() q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) @@ -260,6 +266,7 @@ def forward( "mamba": BambaMixerDecoderLayer } + class BambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -312,10 +319,11 @@ def forward( # add additional attn_metadata for the mixer layers if attn_metadata.num_prefills > 0: sed_idx = torch.zeros_like(input_ids, dtype=torch.int32) - for i, (srt, end) in enumerate(zip( - attn_metadata.query_start_loc, - attn_metadata.query_start_loc[1:], - )): + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): sed_idx[srt:end] = i attn_metadata.seq_idx = sed_idx @@ -335,7 +343,8 @@ def forward( layer_mamba_cache_params = None if isinstance(layer, BambaMixerDecoderLayer): - layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i - num_attn) + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) hidden_states, residual = layer( positions=positions, @@ -457,18 +466,14 @@ def _get_mamba_cache_shape( intermediate_size = self.config.mamba_expand * hidden_size - # if n_groups is not divisible by world_size, need to extend the shards to ensure - # all groups needed by a head is sharded along with it - n_groups = ( - self.config.mamba_n_groups + - extra_groups_for_head_shards(self.config.mamba_n_groups, world_size) - ) + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size)) # - heads and n_groups are TP-ed - conv_dim = ( - intermediate_size + - 2 * n_groups * self.config.mamba_d_state - ) + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) conv_state_shape = ( divide(conv_dim, world_size), self.config.mamba_d_conv - 1,