From 62181d5981f877d91218d118cc5f9e40ea33d9d2 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 5 Dec 2024 08:55:04 +0800 Subject: [PATCH 01/19] initial pr without tp fix Signed-off-by: Yu Chin Fabian Lim --- .../decoder_only/language/test_bamba.py | 326 +++ .../layers/mamba/mamba_mixer2.py | 300 +++ .../layers/mamba/ops/softplus.py | 15 + .../layers/mamba/ops/ssd_bmm.py | 262 +++ .../layers/mamba/ops/ssd_chunk_scan.py | 1829 +++++++++++++++++ .../layers/mamba/ops/ssd_chunk_state.py | 988 +++++++++ .../layers/mamba/ops/ssd_combined.py | 481 +++++ .../layers/mamba/ops/ssd_state_passing.py | 348 ++++ vllm/model_executor/models/bamba.py | 543 +++++ vllm/model_executor/models/registry.py | 1 + 10 files changed, 5093 insertions(+) create mode 100644 tests/models/decoder_only/language/test_bamba.py create mode 100644 vllm/model_executor/layers/mamba/mamba_mixer2.py create mode 100644 vllm/model_executor/layers/mamba/ops/softplus.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_bmm.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_combined.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_state_passing.py create mode 100644 vllm/model_executor/models/bamba.py diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py new file mode 100644 index 0000000000000..f5ae20de63a8a --- /dev/null +++ b/tests/models/decoder_only/language/test_bamba.py @@ -0,0 +1,326 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. + +This actually is really indentical to test_mamba, so maybe we can reuse + +Run `pytest tests/models/decoder_only/language/test_bamba.py`. +""" +import pytest +from transformers import AutoModelForCausalLM, AutoTokenizer + +from vllm.sampling_params import SamplingParams +from vllm.worker.model_runner import _get_graph_batch_size + +from ...utils import check_outputs_equal + +# will be ch +MODELS = ["ibm-fms/Bamba-9.8b-1.8T-hf"] + + +# Use lower-level interfaces to create this greedy generator, as mamba will +# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. +def generate_greedy(model_name, example_prompts, max_tokens): + # Create a text generation pipeline + # - in the original test_mamba.py they do not put the model to cuda + # maybe this affects the test. + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name) + + # Generate texts from the prompts + outputs = [] + for prompt in example_prompts: + # Tokenize the input prompt with truncation + inputs = tokenizer(prompt, return_tensors="pt", truncation=True) + input_ids = inputs["input_ids"] + + # Generate text using the model's generate method directly + generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) + generated_text = tokenizer.decode(generated_ids[0], + skip_special_tokens=True) + + outputs.append((generated_ids[0].tolist(), generated_text)) + + return outputs + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + hf_outputs = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_batching( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + for_loop_outputs = [] + with vllm_runner(model, dtype=dtype) as vllm_model: + for prompt in example_prompts: + for_loop_outputs.append( + vllm_model.generate_greedy([prompt], max_tokens)[0]) + + batched_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=for_loop_outputs, + outputs_1_lst=batched_outputs, + name_0="for_loop_vllm", + name_1="batched_vllm", + ) + +@pytest.mark.skip("bamba does not support chunked prefill yet") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, + model: str, dtype: str, + max_tokens: int) -> None: + # Tests chunked prefill in conjunction with n>1. In this case, prefill is + # populated with decoding tokens and we test that it doesn't fail. + # This test might fail if cache is not allocated correctly for n > 1 + # decoding steps inside a chunked prefill forward pass (where we have both + # prefill and decode together ) + sampling_params = SamplingParams(n=3, + temperature=1, + seed=0, + max_tokens=max_tokens) + with vllm_runner( + model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=30, + max_num_seqs=10 # forces prefill chunks with decoding + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + +@pytest.mark.skip("bamba does not support chunked prefill yet") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, + max_tokens: int, + chunked_prefill_token_size: int) -> None: + """ + Checks exact match decode between huggingface model and vllm runner with + chunked prefill. + """ + max_num_seqs = chunked_prefill_token_size + max_num_batched_tokens = chunked_prefill_token_size + + non_chunked = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: + chunked = vllm_model.generate_greedy(example_prompts, + max_tokens=max_tokens) + + check_outputs_equal( + outputs_0_lst=chunked, + outputs_1_lst=non_chunked, + name_0="chunked", + name_1="non_chunked", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [15]) +def test_parallel_sampling( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + + with vllm_runner(model, dtype=dtype) as vllm_model: + for_loop_outputs = [] + for _ in range(10): + for_loop_outputs.append( + # using example_prompts index 1 instead of 0 since with 0 the + # logprobs get really close and the test doesn't pass + vllm_model.generate_greedy([example_prompts[1]], max_tokens) + [0]) + sampling_params = SamplingParams(n=10, + temperature=0.001, + seed=0, + max_tokens=max_tokens) + n_lt_1_outputs = vllm_model.generate([example_prompts[1]], + sampling_params) + token_ids, texts = n_lt_1_outputs[0] + n_lt_1_outputs = [(token_id, text) + for token_id, text in zip(token_ids, texts)] + + check_outputs_equal( + outputs_0_lst=n_lt_1_outputs, + outputs_1_lst=for_loop_outputs, + name_0="vllm_n_lt_1_outputs", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_mamba_cache_cg_padding( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # This test is for verifying that mamba cache is padded to CG captured + # batch size. If it's not, a torch RuntimeError will be raised because + # tensor dimensions aren't compatible + while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): + example_prompts.append(example_prompts[0]) + + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + except RuntimeError: + pytest.fail( + "Couldn't run batch size which is not equal to a Cuda Graph " + "captured batch size. " + "Could be related to mamba cache not padded correctly") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_models_preemption_recompute( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # Tests that outputs are identical with and w/o preemtions (recompute) + assert dtype == "float" + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = True + preempt_vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = False + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=preempt_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="vllm_preepmtions", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Mamba inner state management doesn't + # collapse in case where the number of incoming requests and + # finished_requests_ids is larger than the maximum Mamba block capacity. + # This could generally happen due to the fact that Mamba does support + # statelessness mechanism where it can cleanup new incoming requests in + # a single step. + try: + with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 100, 10) + except ValueError: + pytest.fail("Mamba inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily ") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_state_cleanup( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Mamba state is cleaned up between + # steps, If its not cleaned, an error would be expected. + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + for _ in range(10): + vllm_model.generate_greedy([example_prompts[0]] * 100, 1) + except ValueError: + pytest.fail("Mamba inner state wasn't cleaned up between states, " + "could be related to finished_requests_ids") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_multistep( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 10, 1) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_multistep_correctness(vllm_runner, model: str, dtype: str, + max_tokens: int, example_prompts) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_outputs_multistep = vllm_model.generate_greedy( + example_prompts, max_tokens) + + with vllm_runner(model, num_scheduler_steps=1, + max_num_seqs=2) as vllm_model: + vllm_outputs_single_step = vllm_model.generate_greedy( + example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_multistep, + outputs_1_lst=vllm_outputs_single_step, + name_0="vllm_outputs_multistep", + name_1="vllm_outputs_single_step", + ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py new file mode 100644 index 0000000000000..f1c114ac9d4c6 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -0,0 +1,300 @@ +import torch +from torch import nn +from torch.nn.parameter import Parameter + +# Added by the IBM Team, 2024 + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) + +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update) +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + + +from typing import Tuple, Union, Optional +from vllm.model_executor.custom_op import CustomOp + +# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated +@CustomOp.register("mixer2_gated_rms_norm") +class Mixer2RMSNormGated(CustomOp): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + pass + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + from vllm import _custom_ops as ops + + # the original code casted gate to float32 before silu + # hidden_states * nn.functional.silu(gate.to(torch.float32)) + out = torch.empty_like(x) + ops.rms_norm( + out, + x * nn.functional.silu(gate), + self.weight.data, + self.variance_epsilon, + ) + return out + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer2") +class MambaMixer2(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + time_step_rank: int, + use_conv_bias: bool, + use_bias: bool, + use_rms_norm: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation="silu", + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.time_step_rank = time_step_rank + self.ssm_state_size = ssm_state_size + self.use_rms_norm = use_rms_norm + self.activation = activation + + self.chunk_size = 256 + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_heads = num_heads + self.n_groups = n_groups + self.conv_dim = intermediate_size + 2 * n_groups * ssm_state_size + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config) + + # unlike mamba_mixer.py (v1), we do not TP the A matrix as it is + # already quite small. + # - same for dt_bias and D + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + param.data.copy_(-torch.exp(loaded_weight.float())) + + self.A = nn.Parameter( + torch.empty( + num_heads, + dtype=torch.float32, + )) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.dt_bias = nn.Parameter(torch.ones(num_heads)) + self.D = nn.Parameter(torch.ones(num_heads)) + + self.out_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config) + + self.norm = Mixer2RMSNormGated( + intermediate_size, eps=rms_norm_eps + ) + + def forward_native(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams): + + + seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + # - doing it differently from mixer v1; little confused with its logic + # - we need to do is to detect if there is any prefill; if there are + # no prefils, then each example will be coming in one sample at a time + # - on the other hand v1 checks for "query_start_loc" and "context_lens_tensor" + # however we have noticed that, even when the samples are coming in + # one at a time, they are still non-NO.e + # * "query_start_loc" = [0, 1, ..] + # * "context_lens_tensor" = [8, ...] + has_prefill = attn_metadata.num_prefills > 0 + + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + gate, hidden_states_B_C, dt = torch.split( + projected_states, + [self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + ) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if has_prefill: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # - "cache_indices" upates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + hidden_states_B_C = causal_conv1d_fn( + hidden_states_B_C.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc + ).transpose(0, 1)[:seq_len] + else: + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor + ) + + # - get hidden_states, B and C after depthwise convolution. + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. State Space Model sequence transformation + if has_prefill: + + # FIXME: we are having problems using mamba_chunk_scan_combined + # with chunked prefill. This is because there is no + # initial_states requires initial_states.shape[0] to match + # the batch size, but cu_seqlens requires batch_size = 1. + # Therefore as of now, initial_states and cu_seqlens are + # mutually exclusive. + + initial_states = None + # if any(attn_metadata.context_lens_tensor > 0): + # initial_states = mamba_cache_params.ssm_state[ + # mamba_cache_params.state_indices_tensor + # ] + + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states.view(1, seq_len, -1, self.head_dim), + dt.unsqueeze(0), + self.A, + B.view(1, seq_len, self.n_groups, -1), + C.view(1, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + seq_idx=attn_metadata.seq_idx.unsqueeze(0), + cu_seqlens=attn_metadata.query_start_loc, + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + for i, idx in enumerate(mamba_cache_params.state_indices_tensor): + mamba_cache_params.ssm_state[idx].copy_(varlen_state[i]) + + # - reshape + hidden_states = scan_output.view(seq_len, -1) + else: + + # NOTE: can be optimized? + A = self.A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(-1, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(-1, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(-1, self.num_heads, self.head_dim) + + # - the hidden is reshaped into number of current batches + # - in this case there is no more prefil, so the batches gen + # 1 token at a time + # - thus hidden will be (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using "mamba_cache_params.state_indices_tensor", just as + # above in the prefill case + + hidden_states = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor, + ) + hidden_states = hidden_states.view(-1, self.num_heads * self.head_dim) + + # # 4. gated MLP + hidden_states = self.norm(hidden_states, gate) + + # # 5. Final linear projection + out, _ = self.out_proj(hidden_states) + return out \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/ops/softplus.py b/vllm/model_executor/layers/mamba/ops/softplus.py new file mode 100644 index 0000000000000..5541655c66160 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/softplus.py @@ -0,0 +1,15 @@ +import triton +import triton.language as tl +from packaging import version + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + + +if TRITON3: + @triton.jit + def softplus(dt): + return tl.math.log(tl.math.exp(dt) + 1) +else: + @triton.jit + def softplus(dt): + return tl.math.log1p(tl.exp(dt)) \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py new file mode 100644 index 0000000000000..48fd4f063e779 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -0,0 +1,262 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, b_ptr, out_ptr, seq_idx_ptr, + # Matrix dimensions + seqlen, chunk_size, K, ngroups, + stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk, + stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'K'], +) +@triton.jit +def _bmm_chunk_bwd_kernel( + # Pointers to matrices + a_ptr, dout_ptr, db_ptr, res_ptr, + # Matrix dimensions + seqlen, chunk_size, K, ngroups, + stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, + stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n, + stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k, + stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k, + # Meta-parameters + dot_dtype: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(K, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_cs = tl.arange(0, BLOCK_SIZE_CS) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m) + a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype) + a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype) + acc += tl.dot(dout, a) + dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m + a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_RESIDUAL: + res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head + res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k) + res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32) + acc += res + db = acc.to(db_ptr.dtype.element_ty) + + db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head + db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k) + tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)) + + +def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. + out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, dtype=out_dtype) + dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), + batch, nchunks if not has_groups else nchunks * ngroups) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, b, out, seq_idx, + seqlen, chunk_size, k, ngroups if has_groups else 1, + a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), + b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1), + out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out + + +def _bmm_chunk_bwd(a, dout, residual=None, out=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + Return: + out: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + + If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be + zeroed out before calling this function. + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + nchunks, chunk_size = dout.shape[1], dout.shape[-1] + if a.stride(-1) != 1 and a.stride(-2) != 1: + a = a.contiguous() + if dout.stride(-1) != 1 and dout.stride(-2) != 1: + dout = dout.contiguous() + if residual is not None: + assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k) + if residual.stride(-1) != 1 and residual.stride(1) != 1: + residual = residual.contiguous() + # Allocates output. + if out is not None: + assert out.shape == a.shape + assert out.stride(-1) == 1 or out.stride(1) == 1 + else: + out = torch.empty_like(a) + dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch, + nchunks if not has_groups else nchunks * ngroups) + residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2), + residual.stride(-1)) + if residual is not None else (0, 0, 0, 0)) + with torch.cuda.device(a.device.index): + _bmm_chunk_bwd_kernel[grid]( + a, dout, out, residual, + seqlen, chunk_size, k, ngroups if has_groups else 1, + a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), + dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1), + out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1), + residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3], + dot_dtype, + HAS_RESIDUAL=residual is not None, + ) + return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py new file mode 100644 index 0000000000000..e77ed026907ac --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -0,0 +1,1829 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +from packaging import version + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or pid_c > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. + cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) + tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_scan_fwd_kernel_wip( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_D_head, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_n = tl.program_id(axis=0) + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE) + + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + # if pid_c == 0: + # if pid_b == 0: + # if pid_h == 0: + # tl.device_print("", prev_states) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # scale_m = tl.exp(dA_cs_m) + # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) + # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # cb *= dt_m + # mask = offs_m[:, None] >= offs_m[None, :] + # cb = tl.where(mask, cb, 0.0) + # cb = cb.to(x_ptr.dtype.element_ty) + # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0) + # acc += tl.dot(cb, x) + # if HAS_D: + # if D_HAS_HDIM: + # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + # else: + # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + # acc += x.to(tl.float32) * D + # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M): + start_m = tl.multiple_of(start_m, BLOCK_SIZE_M) + dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0) + acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) + # cb *= dt_m + # mask = offs_m[:, None] >= offs_m[None, :] + # cb = tl.where(mask, cb, 0.0) + # cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0) + # acc += tl.dot(cb, x) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + acc += x.to(tl.float32) * D + + # if HAS_Z: + # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) + # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + # acc *= z * tl.sigmoid(z) + + tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim)) + + # TODO: this is not correct, and quite a bit slower + if start_m + BLOCK_SIZE_M < chunk_size_limit: + # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32) + B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0) + dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32) + # TODO: seq_idx + scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m + # B *= scale + B = B.to(x_ptr.dtype.element_ty) + tmp = tl.dot(B, x) + prev_states += tmp.to(prev_states.dtype) + + C_ptrs += BLOCK_SIZE_M * stride_C_seqlen + B_ptrs += BLOCK_SIZE_M * stride_B_seqlen + cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_M * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_M * stride_dt_csize + out_ptrs += BLOCK_SIZE_M * stride_out_seqlen + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32}), + triton.Config({'BLOCK_SIZE_M': 64}), + triton.Config({'BLOCK_SIZE_M': 128}), + triton.Config({'BLOCK_SIZE_M': 256}), + ], + key=["chunk_size", "hdim"], +) +@triton.jit +def _chunk_scan_bwd_dz_kernel( + # Pointers to matrices + dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_D_head, + stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim, + stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim, + stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_DDACS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head + if RECOMPUTE_OUTPUT: + outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head + if HAS_DDACS: + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + if HAS_D: + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim) + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim) + dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim) + if RECOMPUTE_OUTPUT: + outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim) + if HAS_D: + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + if RECOMPUTE_OUTPUT: + outz = out * z * z_sigmoid + tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + dout *= z * z_sigmoid + tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + if HAS_D: + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + dD = tl.sum(dout * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + dD = tl.sum(dout * x) + tl.store(dD_ptr, dD) + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out -= x * D + if HAS_DDACS: + ddA_cs = tl.sum(dout * out, axis=1) + tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_scan_bwd_dstates_kernel( + # Pointers to matrices + dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + batch, seqlen, nchunks, nheads_ngroups_ratio, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate, + stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen) + c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale_k = tl.exp(dA_cs_k) + else: + seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) + scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0) + dout = (dout * scale_k).to(dout_ptr.dtype.element_ty) + c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0) + acc += tl.dot(dout, c) + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + c_ptrs += BLOCK_SIZE_K * stride_c_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + out = acc.to(dprev_states_ptr.dtype.element_ty) + + dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate) + tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_dc_kernel( + # Pointers to matrices + dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, + dc_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split + prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + if HAS_DDA_CS: + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + prev_states = prev_states.to(dout_ptrs.dtype.element_ty) + dc = tl.dot(dout, prev_states) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_m) + else: + scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + dc *= scale[:, None] + if HAS_DDA_CS: + ddA_cs = tl.sum(dc * c, axis=1) + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + acc += dc + dout_ptrs += stride_dout_head + prev_states_ptrs += stride_prev_states_head + dA_cumsum_ptrs += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptrs += stride_ddA_cs_head + # if HAS_SEQ_IDX: + # seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + # acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate) + tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr, + dx_ptr, ddt_ptr, # dD_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_D_head, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + # if HAS_D: + # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Idk why limiting K_MAX gives wrong results, is it a Triton bug? + # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + K_MAX = chunk_size_limit + for k in range(0, K_MAX, BLOCK_SIZE_K): + # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, + # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. + # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. + # This will cause NaN in acc, and hence NaN in dx and ddt. + mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + # if HAS_D: + # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) + # dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32) + # dD = tl.sum(x * dout, axis=0) + # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) + + +# Disabling HAS_DDA_CS for now since it's much slower +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim'], +) +# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) +# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32}) +@triton.jit +def _chunk_scan_bwd_dcb_kernel( + # Pointers to matrices + x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + dcb_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + if HAS_DDA_CS: + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + return + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + dcb = tl.dot(dout, x) + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + dcb *= dt_n + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) + dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + if HAS_DDA_CS: + tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") + ddA_cs = dcb * cb + mask = offs_m[:, None] >= offs_n[None, :] + 1 + ddA_cs = tl.where(mask, ddA_cs, 0.0) + ddA_cs = tl.cumsum(ddA_cs, axis=1) + ddA_cs = tl.where(mask, ddA_cs, 0.0) + ddA_cs = tl.sum(ddA_cs, axis=0) + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) + tl.store(ddA_cumsum_ptr, 0.0) + acc += dcb + dout_ptrs += stride_dout_head + x_ptrs += stride_x_head + dt_ptrs += stride_dt_head + dA_cumsum_ptr += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptr += stride_ddA_cs_head + ddA_cumsum_ptrs += stride_ddA_cs_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + mask = offs_m[:, None] >= offs_n[None, :] + acc = tl.where(mask, acc, 0.0) + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + +# Not numerically stable and should not be used. Leaving here for reference. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32}), + triton.Config({'BLOCK_SIZE_M': 64}), + triton.Config({'BLOCK_SIZE_M': 128}), + triton.Config({'BLOCK_SIZE_M': 256}), + ], + key=["chunk_size", "hdim"], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_unstable_kernel( + # Pointers to matrices + dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr, + ddA_cumsum_ptr, dD_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_D_head, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + SUBTRACT_DDTDT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + if HAS_D: + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + if HAS_D: + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + dD = tl.sum(dout * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + dD = tl.sum(dout * x) + tl.store(dD_ptr, dD) + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out -= x * D + ddA_cs = tl.sum(dout * out, axis=1) + if SUBTRACT_DDTDT: + dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddA_cs -= dt * ddt + tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_stable_kernel_old( + # Pointers to matrices + x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, + ddAcs_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + # Doing a matmul loop with cumsum later on will cause Triton to crash + # Instead we do just one big matmul + # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # for k in range(0, hdim, BLOCK_SIZE_K): + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) + # acc += tl.dot(dout, x) + # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim + # x_ptrs += BLOCK_SIZE_K * stride_x_hdim + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + acc = tl.dot(dout, x) + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) + acc *= cb + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + acc *= dt_n + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + mask = offs_m[:, None] >= offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + acc = tl.cumsum(acc, axis=1) + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n + tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) + tl.store(ddAcs_ptr, 0.0) + + # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64) + # offs_k = tl.arange(0, BLOCK_SIZE_K) + # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + # dt_ptrs = dt_ptr + offs_n * stride_dt_csize + # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + + # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m + # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n + # for n in range(0, chunk_size_limit_n, 64): + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0) + # acc = tl.dot(dout, x) + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32) + # acc *= cb + # dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) + # acc *= dt_n + # dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) + # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + # mask = offs_m[:, None] >= offs_n[None, :] + 1 + n + # acc = tl.where(mask, acc, 0.0) + # acc = tl.cumsum(acc, axis=1) + # acc = tl.where(mask, acc, 0.0) + # ddA_cs = tl.sum(acc, axis=0) + # tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n) + # # tl.store(ddAcs_ptr, 0.0) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_stable_kernel( + # Pointers to matrices + x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + tl.store(ddA_cumsum_ptr, 0.0) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower + lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M + # lo, hi = 0, chunk_size + for start_n in range(lo, hi, BLOCK_SIZE_N): + start_n = tl.multiple_of(start_n, BLOCK_SIZE_N) + # Doing a matmul loop with cumsum later on will cause Triton to crash + # Instead we do just one big matmul + # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # for k in range(0, hdim, BLOCK_SIZE_K): + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) + # acc += tl.dot(dout, x) + # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim + # x_ptrs += BLOCK_SIZE_K * stride_x_hdim + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) + acc = tl.dot(dout, x) + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= dt_n + # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j] + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) + acc *= cb + dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + rowsum_new = rowsum + tl.sum(acc, axis=1) + acc = rowsum[:, None] + tl.cumsum(acc, axis=1) + rowsum = rowsum_new + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1) + x_ptrs += BLOCK_SIZE_N * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_N * stride_dt_csize + cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + # Need to zero out the rest, since we'll be summing the rows together + for start_n in range(hi, chunk_size, BLOCK_SIZE_N): + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1) + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_prev_kernel( + # Pointers to matrices + dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nchunks, nheads_ngroups_ratio, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + prev_states = prev_states.to(dout_ptrs.dtype.element_ty) + acc = tl.dot(dout, prev_states) + c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + ddA_cs = tl.sum(acc * c, axis=1) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_m) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + ddA_cs *= scale + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + + +def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Allocates output. + out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + ) + return out, out_x + + +def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert B.shape == C.shape + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Allocates output. + out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel_wip[grid]( + cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + D.stride(0) if D is not None else 0, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + BLOCK_SIZE_M=128, + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out, out_x + + +def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False): + batch, seqlen, nheads, headdim = x.shape + assert z.shape == x.shape + assert out.shape == x.shape + assert dout.shape == out.shape + nchunks = math.ceil(seqlen / chunk_size) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + if has_ddAcs: + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + if D is not None: + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + if dz is not None: + assert dz.shape == z.shape + else: + dz = torch.empty_like(z) + if recompute_output: + outz = torch.empty_like(x) + dout_x = torch.empty_like(dout) + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dz_kernel[grid_dz]( + dout, out, z, x, D, outz if recompute_output else None, + dz, dout_x, dD, ddA_cumsum if has_ddAcs else None, + chunk_size, headdim, + batch, seqlen, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + z.stride(0), z.stride(1), z.stride(2), z.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + D.stride(0) if D is not None else 0, + *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)), + dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3), + dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + *((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + if has_ddAcs else (0, 0, 0, 0)), + D is not None, + D.dim() == 2 if D is not None else True, + has_ddAcs, + BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), + RECOMPUTE_OUTPUT=recompute_output, + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD) + return return_vals if not recompute_output else (*return_vals, outz) + + +def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None): + batch, seqlen, nheads, headdim = dout.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + dtype = C.dtype if dtype is None else dtype + dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype) + grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(C.device.index): + _chunk_scan_bwd_dstates_kernel[grid_dstates]( + dout, C, dprev_states, dA_cumsum, seq_idx, + headdim, dstate, chunk_size, + batch, seqlen, nchunks, nheads // ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return dprev_states + + +def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1): + batch, nchunks, nheads, headdim, dstate = prev_states.shape + _, seqlen, _, _ = dout.shape + _, _, _, chunk_size = dA_cumsum.shape + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == (batch, seqlen, nheads, headdim) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if C is not None: + assert C.shape == (batch, seqlen, ngroups, dstate) + C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3)) + ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3)) + else: + C_strides = (0, 0, 0, 0) + ddA_cumsum_prev = None + ddA_cumsum_prev_strides = (0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32) + grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(dout.device.index): + _chunk_scan_bwd_dc_kernel[grid_dc]( + dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev, + chunk_size, dstate, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + *C_strides, + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4), + *ddA_cumsum_prev_strides, + HAS_DDA_CS=ddA_cumsum_prev is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dC = dC.sum(2) + return dC if C is None else (dC, ddA_cumsum_prev) + + +def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if CB is not None: + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4)) + BLOCK_SIZE_M_min = 16 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4)) + else: + CB_strides = (0, 0, 0, 0, 0) + ddA_cumsum = None + ddA_cumsum_strides = (0, 0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32) + grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dcb_kernel[grid_dcb]( + x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + *CB_strides, + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5), + *ddA_cumsum_strides, + HAS_DDA_CS=ddA_cumsum is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dcb = dcb.sum(2) + if ddA_cumsum is not None: + BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return dcb if CB is None else (dcb, ddA_cumsum) + + +def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + # if D is not None: + # BLOCK_SIZE_M_min = 32 + # dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32) + # else: + # dD = None + dx = torch.empty_like(x) + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dx_kernel[grid_dx]( + x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + D.stride(0) if D is not None else 0, + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0, + D is not None, + D.dim() == 2 if D is not None else True, + ) + # if D is not None: + # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + return dx, ddt.to(dtype=dt.dtype) + + +def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): + """Not numerically stable and should not be used. Leaving here for reference. + """ + + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert ddt.shape == dt.shape + assert out.shape == x.shape + assert dout.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + ddA_cumsum = torch.empty_like(dt) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + if D is not None: # Triton gives wrong results if we write to the same location + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs]( + dout, out, dt, ddt, x, D, ddA_cumsum, dD, + chunk_size, headdim, + batch, seqlen, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + D.stride(0) if D is not None else 0, + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + D is not None, + D.dim() == 2 if D is not None else True, + subtract_ddtdt, + BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return ddA_cumsum, dD + + +def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + assert dA_cumsum.shape == dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + BLOCK_SIZE_M_min = 16 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs]( + x, dout, dt, dA_cumsum, cb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16), + ) + BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum + + +def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + assert dA_cumsum.shape == dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + BLOCK_SIZE_M_min = 32 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs]( + x, dout, dt, dA_cumsum, cb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum + + +def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None): + batch, nchunks, nheads, headdim, dstate = prev_states.shape + _, seqlen, _, _ = dout.shape + _, _, _, chunk_size = dA_cumsum.shape + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == (batch, seqlen, nheads, headdim) + ngroups = C.shape[2] + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(dout.device.index): + _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs]( + dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev, + chunk_size, dstate, headdim, + batch, seqlen, nchunks, nheads // ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3), + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + return ddA_cumsum_prev + + +class ChunkScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + CB = _bmm_chunk_fwd(C, B, chunk_size) + out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z) + ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z) + return out + + @staticmethod + def backward(ctx, dout): + if dout.stride(-1) != 1: + dout = dout.contiguous() + out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert dout.shape == (batch, seqlen, nheads, headdim) + if z is not None: + dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D) + else: + dz = None + dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype) + dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups) + dC = dC.to(C.dtype) + dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups) + dCB = dCB.to(CB.dtype) + dB = _bmm_chunk_bwd(C, dCB) + dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC) + dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D) + # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. + # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt + if z is not None: + ddA_cumsum -= ddt * dt + else: # If z is not None, we already calculated ddA_cumsum and dD when computing dz + ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D) + ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype) + return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz + + +def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1. + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z) + + +def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + assert C.shape == B.shape + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), + rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + prev_states.to(C.dtype)) * state_decay_out + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return out if z is None else out * F.silu(z) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py new file mode 100644 index 0000000000000..af14bb9fb8022 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -0,0 +1,988 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +from .softplus import softplus + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + # Matrix dimension + batch, seqlen, nheads, chunk_size, + dt_min, dt_max, + # Strides + stride_dt_batch, stride_dt_seqlen, stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + pid_c = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) + tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_bwd_kernel( + # Pointers to matrices + ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr, + ddt_ptr, dA_ptr, ddt_bias_ptr, + # Matrix dimensions + batch, seqlen, nheads, chunk_size, + dt_min, dt_max, + # Strides + stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize, + stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize, + stride_dt_batch, stride_dt_seqlen, stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head, + stride_dA_head, + stride_ddt_bias_head, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + pid_c = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk + ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize) + ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) + ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + ddt = ddA * A[:, None] + ddt_out + dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt_presoftplus = dt + dt = tl.where(dt <= 20.0, softplus(dt), ddt) + clamp_mask = (dt < dt_min) | (dt > dt_max) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) + ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0) + ddt = tl.where(clamp_mask, 0.0, ddt) + if DT_SOFTPLUS: + ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt) + tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + dA = tl.sum(ddA * dt, axis=1) + tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) + if HAS_DT_BIAS: + ddt_bias = tl.sum(ddt, axis=1) + tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_state_bwd_dx_kernel( + # Pointers to matrices + x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, + dx_ptr, ddt_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) + if BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) + else: + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + ddA_cs = -(ddt * dt_m) + ddA_cs_last = -tl.sum(ddA_cs) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) + + dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty) + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_state_bwd_db_kernel( + # Pointers to matrices + x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + db_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + if HAS_DDA_CS: + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + dstates = dstates.to(x_ptrs.dtype.element_ty) + db = tl.dot(x, dstates) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + db *= (scale * dt_m)[:, None] + if HAS_DDA_CS: + # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum + ddA_cs = tl.sum(db * b, axis=1) + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + acc += db + x_ptrs += stride_x_head + dstates_ptrs += stride_states_head + dt_ptrs += stride_dt_head + dA_cumsum_ptr += stride_dA_cs_head + dA_cumsum_ptrs += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptrs += stride_ddA_cs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # if HAS_SEQ_IDX: + # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0) + db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate) + tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_state_bwd_ddAcs_stable_kernel( + # Pointers to matrices + x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) + if BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) + else: + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + acc *= scale[:, None] + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + # ddA_cs = -(ddt * dt_m) + # Triton 2.2.0 errors if we have the cumsum here, so we just write it out + # then call torch.cumsum outside this kernel. + # ddA_cs = tl.cumsum(ddt * dt_m) + ddA_cs = ddt * dt_m + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + seqlen, nheads_ngroups_ratio, + # Strides + stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate, + stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk + if start_idx < pid_c * chunk_size: + chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate) + chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0) + scale = tl.exp(dA_cs_last) + acc += chunk_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, A, dt_bias, dt_out, dA_cumsum, + batch, seqlen, nheads, chunk_size, + dt_limit[0], dt_limit[1], + dt.stride(0), dt.stride(1), dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None): + batch, seqlen, nheads = dt.shape + _, _, nchunks, chunk_size = ddA.shape + assert ddA.shape == (batch, nheads, nchunks, chunk_size) + assert ddt_out.shape == (batch, nheads, nchunks, chunk_size) + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) + else: + ddt_bias = None + if ddt is not None: + assert ddt.shape == dt.shape + else: + ddt = torch.empty_like(dt) + dA = torch.empty_like(A, dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_bwd_kernel[grid_chunk_cs]( + ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias, + batch, seqlen, nheads, chunk_size, + dt_limit[0], dt_limit[1], + ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3), + ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3), + dt.stride(0), dt.stride(1), dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + ddt.stride(0), ddt.stride(1), ddt.stride(2), + dA.stride(0), + ddt_bias.stride(0) if ddt_bias is not None else 0, + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return ddt, dA, ddt_bias + + +def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, B, states, dt, dA_cumsum, seq_idx, + headdim, dstate, chunk_size, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if dx is not None: + assert dx.shape == x.shape + else: + dx = torch.empty_like(x) + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_bwd_dx_kernel[grid_dx]( + x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + ) + return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype) + + +def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + dstate = dstates.shape[-1] + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B is not None: + assert B.shape == (batch, seqlen, ngroups, dstate) + B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3)) + # Use torch.empty since the Triton kernel will call init_to_zero + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + else: + B_strides = (0, 0, 0, 0) + ddA_cumsum = None + ddA_cumsum_strides = (0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32) + grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(x.device.index): + _chunk_state_bwd_db_kernel[grid_db]( + x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum, + chunk_size, dstate, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + *B_strides, + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4), + *ddA_cumsum_strides, + HAS_DDA_CS=ddA_cumsum is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dB = dB.sum(2) + if ddA_cumsum is not None: + # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute + # to the state of the chunk. + # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) + # But it's easier to just do the cumsum for all elements, the result will be the same. + torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum) + return dB if B is None else (dB, ddA_cumsum) + + +def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Use torch.empty since the Triton kernel will call init_to_zero + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs]( + x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + ) + torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) + return ddA_cumsum + + +def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states, + headdim, dstate, chunk_size, + total_seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), + B.stride(0), B.stride(1), B.stride(2), + dt.stride(1), dt.stride(0), dt.stride(2), + dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2), + chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), + ) + return states + + +class ChunkStateFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if B.stride(-1) != 1: + B = B.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32) + ctx.save_for_backward(B, x, dt, dA_cumsum) + return states + + @staticmethod + def backward(ctx, dstates): + B, x, dt, dA_cumsum = ctx.saved_tensors + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if dstates.stride(-1) != 1: + dstates = dstates.contiguous() + dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates) + dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups) + dB = dB.to(B.dtype) + return dB, dx, ddt, ddA_cumsum, None + + +def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True): + """ + Argument: + B: (batch, seqlen, ngroups, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32) + + +def chunk_state_ref(B, x, dt, dA_cumsum): + """ + Argument: + B: (batch, seqlen, ngroups, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + ngroups = B.shape[2] + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seqlen < nchunks * chunk_size: + x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) + B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) + decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py new file mode 100644 index 0000000000000..a6fb60c199667 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -0,0 +1,481 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +from packaging import version + +import torch + +import triton +import triton.language as tl + +from einops import rearrange + +from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd +from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd +from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db +from .ssd_chunk_state import chunk_state_varlen +from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd +from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates +from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb +from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_scan_chunk_state_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr, + b_ptr, dstates_ptr, + dx_ptr, ddt_ptr, dD_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_D_head, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + # However, we're getting error with the Triton compiler 2.1.0 for that code path: + # Unexpected mma -> mma layout conversion + # Triton 2.2.0 fixes this + offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate) + if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) * scale[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate + acc *= scale[:, None] + + # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + # dt_ptrs = dt_ptr + offs_m * stride_dt_csize + # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + # ddt = tl.sum(acc * x, axis=1) * dt_m + # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit + K_MIN = pid_m * BLOCK_SIZE_M + cb_ptrs += K_MIN * stride_cb_csize_k + dout_ptrs += K_MIN * stride_dout_seqlen + dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize + for k in range(K_MIN, K_MAX, BLOCK_SIZE_K): + k = tl.multiple_of(k, BLOCK_SIZE_K) + # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, + # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. + # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. + # This will cause NaN in acc, and hence NaN in dx and ddt. + mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + dD = tl.sum(dout_res * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + else: + dD = tl.sum(dout_res * x) + tl.store(dD_ptr, dD) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + +def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + if dx is None: + dx = torch.empty_like(x) + else: + assert dx.shape == x.shape + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx]( + x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + D.stride(0) if D is not None else 0, + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + D is not None, + D.dim() == 2 if D is not None else True, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + IS_TRITON_22=TRITON_22 + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return dx, ddt.to(dtype=dt.dtype), dD + +def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size) + # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True) + # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True) + # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True) + states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype) + states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]] + # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) + # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) + CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx) + if cu_seqlens is None: + return out, out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0), + cu_seqlens, states.squeeze(0)) + return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, + dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, + dt_limit=(0.0, float("inf")), + dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False): + if dout.stride(-1) != 1: + dout = dout.contiguous() + batch, seqlen, nheads, headdim = x.shape + nchunks = math.ceil(seqlen / chunk_size) + _, _, ngroups, dstate = B.shape + assert dout.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert C.shape == B.shape + assert out.shape == x.shape + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if dx is not None: + assert dx.shape == x.shape + if dB is not None: + assert dB.shape == B.shape + dB_given = dB + else: + dB_given = torch.empty_like(B) + if dC is not None: + assert dC.shape == C.shape + dC_given = dC + else: + dC_given = torch.empty_like(C) + if dz is not None: + assert z is not None + assert dz.shape == z.shape + if ddt is not None: + assert ddt.shape == dt.shape + ddt_given = ddt + else: + ddt_given = torch.empty_like(dt) + # TD: For some reason Triton (2.1.0 and 2.2.0) errors with + # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why. + dt_in = dt.clone() + dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, + dt_limit=dt_limit) + CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size) + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + if z is not None: + dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output) + outz = rest[0] if recompute_output else out + else: + dz = None + outz = out + dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype) + # dstates has length nchunks, containing the gradient to initial states at index 0 and + # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1) + # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states + # will be used in matmul in the next kernels. + dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + rearrange(dstates, "... p n -> ... (p n)"), + dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None, + seq_idx=seq_idx, + has_initial_states=initial_states is not None, + dstates_dtype=x.dtype, + states_dtype=x.dtype, + chunk_size=chunk_size, + ) + # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and + # gradient to the final states at index (nchunks - 1) + # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1) + # The final states is not stored. + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate) + dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None + dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx) + # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups) + dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups) + # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) + dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups) + # Computing ddA with the dcb kernel is much slower, so we're not using it for now + dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) + # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups) + dCB = dCB.to(CB.dtype) + _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given) + _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given) + # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate + # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16 + if z is None: + dD = dD_from_x + # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. + # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt + # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might + # be a lot of underflow. + + # This is already done as part of bwd_dC kernel + # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx) + ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum + ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1]) + # This is already done as part of bwd_dB kernel + # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx) + # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j] + ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB) + ddA += ddA_next + ddA_prev + + ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given) + + # These 2 lines are just to test ddt and dA being computed by old code + # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z) + # ddt_given.copy_(ddt) + + return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states) + return return_vals if not recompute_output else (*return_vals, outz) + +class MambaChunkScanCombinedFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): + ctx.dt_dtype = dt.dtype + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) + ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx) + ctx.dt_softplus = dt_softplus + ctx.chunk_size = chunk_size + ctx.dt_limit = dt_limit + ctx.return_final_states = return_final_states + ctx.return_varlen_states = return_varlen_states + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) + + @staticmethod + def backward(ctx, dout, *args): + out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors + assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward" + dfinal_states = args[0] if ctx.return_final_states else None + dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit) + return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None + +def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + Return: + out: (batch, seqlen, nheads, headdim) + """ + return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states) \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py new file mode 100644 index 0000000000000..63863b8236e1c --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -0,0 +1,348 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, + # Matrix dimensions + dim, nchunks, seqlen, chunk_size, + # Strides + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim, + stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, + stride_final_states_batch, stride_final_states_head, stride_final_states_dim, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_initstates_batch, stride_initstates_head, stride_initstates_dim, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + if HAS_INITSTATES: + initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + else: + initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + seq_idx = seq_idx_new + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_bwd_kernel( + # Pointers to matrices + dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr, + dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr, + # Matrix dimensions + dim, nchunks, seqlen, chunk_size, + # Strides + stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim, + stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, + stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim, + # Meta-parameters + CONVERT_STATES: tl.constexpr, + HAS_DFINAL_STATES: tl.constexpr, + HAS_DINITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk + ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk + dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk + if CONVERT_STATES: + states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk + if HAS_DFINAL_STATES: + dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head + if HAS_DINITSTATES: + dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + dout_ptrs = dout_ptr + offs_m * stride_dout_dim + if CONVERT_STATES: + states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim + + if HAS_DFINAL_STATES: + dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32) + else: + dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + tl.store(dstates_ptrs, dstates, mask=offs_m < dim) + if HAS_SEQ_IDX: + seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen) + dstates_ptrs -= stride_dstates_chunk + for c in range(nchunks - 1): + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen)) + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + seq_idx = seq_idx_new + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if CONVERT_STATES: + tl.store(states_converted_ptrs, out, mask=offs_m < dim) + ddA = tl.sum(out * dstates) * scale + tl.store(ddA_cs_ptr, ddA) + dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dstates = scale * dstates + dout + tl.store(dstates_ptrs, dstates, mask=offs_m < dim) + dout_ptrs -= stride_dout_chunk + dstates_ptrs -= stride_dstates_chunk + dA_cs_ptr -= stride_dA_cs_chunk + ddA_cs_ptr -= stride_ddA_cs_chunk + out_ptrs -= stride_out_chunk + if CONVERT_STATES: + states_converted_ptrs -= stride_out_chunk + if CONVERT_STATES: + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + tl.store(states_converted_ptrs, out, mask=offs_m < dim) + if not HAS_DINITSTATES: + tl.store(ddA_cs_ptr, 0.0) + else: + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + scale = tl.where(seq_idx == 0, scale, 0.0) + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + ddA = tl.sum(out * dstates) * scale + tl.store(ddA_cs_ptr, ddA) + dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dstates = scale * dstates + dout + tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim) + + +def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, + out_dtype=None): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if initial_states is not None: + assert initial_states.shape == (batch, nheads, dim) + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx, + dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, + states.stride(0), states.stride(1), states.stride(2), states.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + final_states.stride(0), final_states.stride(1), final_states.stride(2), + dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) + if initial_states is not None else (0, 0, 0)), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out, final_states + + +def _state_passing_bwd( + states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None, + dstates_dtype=None, states_dtype=None, chunk_size=None +): + """ + states contains the initial_states at index 0. The final states are not included in states. + """ + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + assert dout.shape == (batch, nchunks, nheads, dim) + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) + if states_dtype is not None and states_dtype != states.dtype: + states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) + assert states_converted.stride() == states.stride() + else: + states_converted = None + if has_initial_states: + dinitstates = torch.empty_like(dstates[:, 0]) + else: + dinitstates = None + if dfinal_states is not None: + assert dfinal_states.shape == (batch, nheads, dim) + BLOCK_SIZE_min = 64 + n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min + ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks, + dtype=torch.float32, device=dA_chunk_cumsum.device) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(dout.device.index): + _state_passing_bwd_kernel[grid]( + dout, states, dA_chunk_cumsum, dfinal_states, seq_idx, + dstates, ddA_chunk_cumsum, dinitstates, states_converted, + dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), + dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), + *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2)) + if dfinal_states is not None else (0, 0, 0)), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), + ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1), + *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2)) + if dinitstates is not None else (0, 0, 0)), + CONVERT_STATES=states_converted is not None, + HAS_DFINAL_STATES=dfinal_states is not None, + HAS_DINITSTATES=dinitstates is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"] + n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype) + if states_dtype is not None and states_dtype == states.dtype: + states_converted = states + return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted) + + +class StatePassingFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, states, dA_chunk_cumsum, initial_states=None): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if states.stride(-1) != 1: + states = states.contiguous() + out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states) + ctx.save_for_backward(out, dA_chunk_cumsum) + ctx.has_initial_states = initial_states is not None + return out, final_states + + @staticmethod + def backward(ctx, dout, dfinal_states): + out, dA_chunk_cumsum = ctx.saved_tensors + batch, nchunks, nheads, dim = out.shape + assert dout.shape == (batch, nchunks, nheads, dim) + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + assert dfinal_states.shape == (batch, nheads, dim) + if dout.stride(-1) != 1: + dout = dout.contiguous() + dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd( + out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states + ) + return dstates, ddA_chunk_cumsum, dinitstates + + +def state_passing(states, dA_chunk_cumsum, initial_states=None): + """ + Argument: + states: (batch, nchunks, nheads, dim) + dA_chunk_cumsum: (batch, nheads, nchunks) + initial_states: (batch, nheads, dim) + Return: + out: (batch, nchunks, nheads, dim) + final_states: (batch, nheads, dim) + """ + return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states) + + +def state_passing_ref(states, dA_chunk_cumsum, initial_states=None): + """ + Argument: + states: (batch, nchunks, nheads, dim) + dA_chunk_cumsum: (batch, nheads, nchunks) + initial_states: (batch, nheads, dim) + Return: + out: (batch, nchunks, nheads, dim) + final_states: (batch, nheads, dim) + """ + if initial_states is None: + initial_states = torch.zeros_like(states[:, 0]) + states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1) + dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0)) + dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1) + nchunks = dA_chunk_cumsum.shape[-1] + # (batch, nheads, nchunks, nchunks) + dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :] + # (batch, nheads, nchunks, nchunks) + decay_chunk = torch.exp(dt_chunk_segment_sum) + causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0) + decay_chunk = decay_chunk.masked_fill(~causal_mask, 0) + out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states) + return out[:, :-1], out[:, -1] diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py new file mode 100644 index 0000000000000..e200ea485718d --- /dev/null +++ b/vllm/model_executor/models/bamba.py @@ -0,0 +1,543 @@ +"""Inference-only Bamba model.""" +# Added by the IBM Team, 2024 +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import BambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) + +from .interfaces import HasInnerState, SupportsLoRA +from .utils import maybe_prefix + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class BambaMLP(nn.Module): + + def __init__( + self, + config: BambaConfig, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + +class BambaMixerDecoderLayer(nn.Module): + + def __init__(self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.mamba = MambaMixer2(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + time_step_rank = config.mamba_dt_rank, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + use_rms_norm=True, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + quant_config=quant_config) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, attn_metadata, + mamba_cache_params) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class BambaAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.rotary_emb = RotaryEmbedding( + head_size=self.head_dim, + rotary_dim=config.attn_rotary_emb, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # because the bamba model may potentially handle long sequences, + # we should adjust the sin_cos cache if necesary to avoid out of bounds + # - first get the max_position + max_position = max( + getattr(attn_metadata, 'max_prefill_seq_len', 0), + getattr(attn_metadata, 'max_decode_seq_len', 0), + ) + if max_position == 0: + # if we cannot get the max lenght from the metadata, then + # get it frmo the positions + max_position = positions.max().item() + + if self.rotary_emb.max_position_embeddings <= max_position: + # we set it to the next power of two that covers it + while self.rotary_emb.max_position_embeddings <= max_position: + self.rotary_emb.max_position_embeddings *= 2 + self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache() + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": BambaAttentionDecoderLayer, + "mamba": BambaMixerDecoderLayer +} + +class BambaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]] + decoder_layers.append( + layer_class(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{i}")) + self.layers = nn.ModuleList(decoder_layers) + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # add additional attn_metadata for the mixer layers + if attn_metadata.num_prefills > 0: + sed_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate(zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + sed_idx[srt:end] = i + + attn_metadata.seq_idx = sed_idx + + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + num_attn = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = None + if isinstance(layer, BambaAttentionDecoderLayer): + kv_cache = kv_caches[num_attn] + num_attn += 1 + + layer_mamba_cache_params = None + if isinstance(layer, BambaMixerDecoderLayer): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i - num_attn) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + mamba_cache_params=layer_mamba_cache_params) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Bamba currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = BambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + + layers_type = self.config.layers_block_type + num_mamba_layers = sum( + [layer_type == "mamba" for layer_type in layers_type]) + + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + *self._get_mamba_cache_shape()) + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_params, + inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = self.config.mamba_expand * hidden_size + + conv_dim = ( + intermediate_size + + 2 * self.config.mamba_n_groups * self.config.mamba_d_state + ) + conv_state_shape = ( + conv_dim // world_size, + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + self.config.mamba_n_heads, + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c66fbce018a62..44b89d9744bd4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -38,6 +38,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), # ChatGLMModel supports multimodal "CohereForCausalLM": ("commandr", "CohereForCausalLM"), From 51bc78c504b849d55247dd3c066f2eb36536dd24 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 5 Dec 2024 03:21:57 +0000 Subject: [PATCH 02/19] fix casting in rms norm gated Signed-off-by: Yu Chin Fabian Lim --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index f1c114ac9d4c6..ecb743613361c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -47,12 +47,12 @@ def forward_cuda( from vllm import _custom_ops as ops - # the original code casted gate to float32 before silu - # hidden_states * nn.functional.silu(gate.to(torch.float32)) + # cast gate to float32 before silu out = torch.empty_like(x) + y = x * nn.functional.silu(gate.to(torch.float32)) ops.rms_norm( out, - x * nn.functional.silu(gate), + y.to(x.dtype), self.weight.data, self.variance_epsilon, ) From 81b93b40933a9423a9c9acf0cca88e35fa457875 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 5 Dec 2024 06:15:38 +0000 Subject: [PATCH 03/19] TP fix Signed-off-by: Yu Chin Fabian Lim --- .../layers/mamba/mamba_mixer2.py | 211 +++++++++++++++--- vllm/model_executor/models/bamba.py | 20 +- 2 files changed, 197 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index ecb743613361c..b2a4b2aaefc78 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -18,12 +18,17 @@ mamba_chunk_scan_combined) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.distributed import (divide, get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, + tensor_model_parallel_all_reduce) +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, sharded_weight_loader, LoaderFunction) - -from typing import Tuple, Union, Optional +from typing import Tuple, Union, Optional, List from vllm.model_executor.custom_op import CustomOp # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated +# also referenced https://github.com/vllm-project/vllm/pull/9292 @CustomOp.register("mixer2_gated_rms_norm") class Mixer2RMSNormGated(CustomOp): def __init__(self, hidden_size, eps=1e-6): @@ -31,13 +36,31 @@ def __init__(self, hidden_size, eps=1e-6): self.hidden_size = hidden_size self.variance_epsilon = eps self.weight = nn.Parameter(torch.ones(hidden_size)) + self.tp_size = get_tensor_model_parallel_world_size() + set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) def forward_native( self, x: torch.Tensor, gate: torch.Tensor, ): - pass + input_dtype = x.dtype + x = x * nn.functional.silu(gate.to(torch.float32)) + + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = (global_sums / count) + + else: + variance = x.pow(2).mean(-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * x.to(input_dtype) def forward_cuda( self, @@ -45,9 +68,12 @@ def forward_cuda( gate: torch.Tensor, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.tp_size > 1: + return self.forward_native(x, gate) + from vllm import _custom_ops as ops - # cast gate to float32 before silu + # cast x and gate to float32 before silu out = torch.empty_like(x) y = x * nn.functional.silu(gate.to(torch.float32)) ops.rms_norm( @@ -58,6 +84,57 @@ def forward_cuda( ) return out +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the extra (logical) groups to account for head shards""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + return tp_size - ngroups % tp_size + +def mamba_v2_sharded_weight_loader( + shard_spec: List[int], tp_size: int, tp_rank: int, +) -> LoaderFunction: + """Create a weight loader for mamba v2. This ensures that the projections are + correctly sharded so that they can be split into x, B, C. It also ensures the + the all the groups corresponding to a head shard is placed together with it. + """ + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + + # - track boundary of (sharded) param, and loaded_weight, respectively + boundary, loaded_boundary = 0, 0 + for full_dim, extra, ratio in shard_spec: + # - full dim is the expected size of the model + # - if extra > 0, this means there was some expansion + + # - num of dims expected to be loaded + shard_size = full_dim // tp_size + + # - compute where to take the loaded shard from + rank = tp_rank // ratio + + # - should start from here (determined by rank) + loaded_skip = rank * shard_size # take these number dims from loaded + loaded_start_idx = loaded_boundary + loaded_skip + + # - these many number dims to take from loaded_weight + take = min(shard_size, full_dim - extra - loaded_skip) + + # - always shard on dim 0 + param.data[ + boundary:boundary+take,... + ] = loaded_weight[ + loaded_start_idx:loaded_start_idx+take + ] + + # move boundaries + boundary += shard_size + loaded_boundary += (full_dim - extra) + + return loader + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer @CustomOp.register("mamba_mixer2") class MambaMixer2(CustomOp): @@ -76,7 +153,6 @@ def __init__(self, ssm_state_size: int, conv_kernel_size: int, intermediate_size: int, - time_step_rank: int, use_conv_bias: bool, use_bias: bool, use_rms_norm: bool, @@ -87,7 +163,22 @@ def __init__(self, activation="silu", quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.time_step_rank = time_step_rank + + # For TP, the sharding plan is as follows: + # - for the conv modules, since + # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, + # we shard intermediate_size and n_groups + # - since intermediate_size = n_heads * head_dim, sharding on + # intermediate_size is achieved by sharding on n_heads. + # - so if world_size divides groups, then sharding + # (n_groups / world_size, n_heads / world_size) + # also maintains the invariant n_heads % n_groups == 0 + # - HOWEVER< if world_size DOES NOT divide groups, then we need to allocate + # extra space in the shard, such that the WHOLE GROUP must be placed + # together with the HEAD SHARD. + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + self.ssm_state_size = ssm_state_size self.use_rms_norm = use_rms_norm self.activation = activation @@ -96,8 +187,17 @@ def __init__(self, self.intermediate_size = intermediate_size self.head_dim = head_dim self.num_heads = num_heads + self.n_groups = n_groups - self.conv_dim = intermediate_size + 2 * n_groups * ssm_state_size + if n_groups % self.tp_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + self.n_groups = n_groups + extra_groups_for_head_shards(n_groups, self.tp_size) + + self.conv_dim = ( + intermediate_size + 2 * self.n_groups * ssm_state_size + ) self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, output_size=self.conv_dim, @@ -116,22 +216,66 @@ def __init__(self, bias=use_bias, quant_config=quant_config) - # unlike mamba_mixer.py (v1), we do not TP the A matrix as it is - # already quite small. - # - same for dt_bias and D - - def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): - param.data.copy_(-torch.exp(loaded_weight.float())) + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups to the head shards + group_shard_settings = ( + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned + self.num_heads // n_groups, # ratio for mapping back to original group + ) + intemediate_settings = (intermediate_size, 0, 1) + head_setings = (self.num_heads, 0, 1) + + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs(self.conv1d.bias, { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intemediate_settings, group_shard_settings, group_shard_settings, + ], + self.tp_size, tp_rank, + ) + }) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs(self.conv1d.weight, { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intemediate_settings, group_shard_settings, group_shard_settings, + ], + self.tp_size, tp_rank + ) + }) + + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs(self.in_proj.weight, { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intemediate_settings, # for gate + intemediate_settings, group_shard_settings, group_shard_settings, + head_setings, # for dt + ], + self.tp_size, tp_rank + ) + }) + # - these are TPed by heads to reduce the size of the + # temporal shape self.A = nn.Parameter( torch.empty( - num_heads, - dtype=torch.float32, + divide(num_heads, self.tp_size), dtype=torch.float32, )) - set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) - self.dt_bias = nn.Parameter(torch.ones(num_heads)) - self.D = nn.Parameter(torch.ones(num_heads)) + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) self.out_proj = RowParallelLinear( intermediate_size, @@ -141,7 +285,7 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): quant_config=quant_config) self.norm = Mixer2RMSNormGated( - intermediate_size, eps=rms_norm_eps + intermediate_size // self.tp_size, eps=rms_norm_eps ) def forward_native(self, hidden_states: torch.Tensor, @@ -171,7 +315,11 @@ def forward_cuda(self, hidden_states: torch.Tensor, projected_states, _ = self.in_proj(hidden_states) gate, hidden_states_B_C, dt = torch.split( projected_states, - [self.intermediate_size, self.conv_dim, self.num_heads], + [ + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, + self.num_heads // self.tp_size, + ], dim=-1, ) @@ -212,7 +360,11 @@ def forward_cuda(self, hidden_states: torch.Tensor, # - get hidden_states, B and C after depthwise convolution. hidden_states, B, C = torch.split( hidden_states_B_C, - [self.intermediate_size, groups_time_state_size, groups_time_state_size], + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], dim=-1, ) @@ -233,11 +385,11 @@ def forward_cuda(self, hidden_states: torch.Tensor, # ] scan_output, varlen_state = mamba_chunk_scan_combined( - hidden_states.view(1, seq_len, -1, self.head_dim), + hidden_states.view(1, seq_len, self.num_heads // self.tp_size, self.head_dim), dt.unsqueeze(0), self.A, - B.view(1, seq_len, self.n_groups, -1), - C.view(1, seq_len, self.n_groups, -1), + B.view(1, seq_len, self.n_groups // self.tp_size, -1), + C.view(1, seq_len, self.n_groups // self.tp_size, -1), chunk_size=self.chunk_size, D=self.D, z=None, @@ -261,13 +413,14 @@ def forward_cuda(self, hidden_states: torch.Tensor, else: # NOTE: can be optimized? + n_groups = self.n_groups // self.tp_size A = self.A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) dt = dt[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) D = self.D[:, None, ...].expand(-1, self.head_dim) - B = B.view(-1, self.n_groups, B.shape[1] // self.n_groups) - C = C.view(-1, self.n_groups, C.shape[1] // self.n_groups) - hidden_states_reshaped = hidden_states.view(-1, self.num_heads, self.head_dim) + B = B.view(-1, n_groups, B.shape[1] // n_groups) + C = C.view(-1, n_groups, C.shape[1] // n_groups) + hidden_states_reshaped = hidden_states.view(-1, self.num_heads // self.tp_size, self.head_dim) # - the hidden is reshaped into number of current batches # - in this case there is no more prefil, so the batches gen @@ -290,7 +443,9 @@ def forward_cuda(self, hidden_states: torch.Tensor, dt_softplus=True, state_batch_indices=mamba_cache_params.state_indices_tensor, ) - hidden_states = hidden_states.view(-1, self.num_heads * self.head_dim) + hidden_states = hidden_states.view( + -1, (self.num_heads // self.tp_size) * self.head_dim + ) # # 4. gated MLP hidden_states = self.norm(hidden_states, gate) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index e200ea485718d..a12ee30798c68 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -9,14 +9,15 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -83,7 +84,6 @@ def __init__(self, conv_kernel_size = config.mamba_d_conv, intermediate_size = config.mamba_expand *\ config.hidden_size, - time_step_rank = config.mamba_dt_rank, use_conv_bias = config.mamba_conv_bias, use_bias = config.mamba_proj_bias, use_rms_norm=True, @@ -459,12 +459,20 @@ def _get_mamba_cache_shape( intermediate_size = self.config.mamba_expand * hidden_size + # if n_groups is not divisible by world_size, need to extend the shards to ensure + # all groups needed by a head is sharded along with it + n_groups = ( + self.config.mamba_n_groups + + extra_groups_for_head_shards(self.config.mamba_n_groups, world_size) + ) + + # - heads and n_groups are TP-ed conv_dim = ( intermediate_size + - 2 * self.config.mamba_n_groups * self.config.mamba_d_state + 2 * n_groups * self.config.mamba_d_state ) conv_state_shape = ( - conv_dim // world_size, + divide(conv_dim, world_size), self.config.mamba_d_conv - 1, ) @@ -472,7 +480,7 @@ def _get_mamba_cache_shape( # - they are typically small # e.g., (h_heads, d_head, d_state) = (128, 64, 128) temporal_state_shape = ( - self.config.mamba_n_heads, + divide(self.config.mamba_n_heads, world_size), self.config.mamba_d_head, self.config.mamba_d_state, ) From 0f93e4aed932e45111af140a4917fa6f966eee86 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sun, 8 Dec 2024 09:31:45 +0000 Subject: [PATCH 04/19] fix mamba scan invalid address Signed-off-by: Yu Chin Fabian Lim --- vllm/model_executor/layers/mamba/ops/ssd_bmm.py | 2 +- vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py | 2 +- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 48fd4f063e779..1a4ddb13811c7 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -48,7 +48,7 @@ def _bmm_chunk_fwd_kernel( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_b = tl.program_id(axis=1) - pid_ch = tl.program_id(axis=2) + pid_ch = tl.program_id(axis=2).to(tl.int64) pid_c = pid_ch // ngroups pid_h = pid_ch - pid_c * ngroups num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index e77ed026907ac..c1fabf0ac5904 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -67,7 +67,7 @@ def _chunk_scan_fwd_kernel( BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, ): - pid_bc = tl.program_id(axis=1) + pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index af14bb9fb8022..5116735d2840b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -49,7 +49,10 @@ def _chunk_cumsum_fwd_kernel( BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, ): pid_b = tl.program_id(axis=0) - pid_c = tl.program_id(axis=1) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) pid_h = tl.program_id(axis=2) dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk @@ -191,7 +194,7 @@ def _chunk_state_fwd_kernel( HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): - pid_bc = tl.program_id(axis=1) + pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) From 742ae799c898dd03d0465797e56436641980ba84 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 12 Dec 2024 05:31:23 +0000 Subject: [PATCH 05/19] some fixes and remove unused kernels Signed-off-by: Yu Chin Fabian Lim --- .../decoder_only/language/test_bamba.py | 4 +- .../layers/mamba/ops/ssd_bmm.py | 124 -- .../layers/mamba/ops/ssd_chunk_scan.py | 1613 +---------------- .../layers/mamba/ops/ssd_chunk_state.py | 640 ------- .../layers/mamba/ops/ssd_combined.py | 406 +---- .../layers/mamba/ops/ssd_state_passing.py | 236 --- vllm/model_executor/models/bamba.py | 6 +- 7 files changed, 21 insertions(+), 3008 deletions(-) diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py index f5ae20de63a8a..a3bcb644baf8b 100644 --- a/tests/models/decoder_only/language/test_bamba.py +++ b/tests/models/decoder_only/language/test_bamba.py @@ -7,8 +7,8 @@ import pytest from transformers import AutoModelForCausalLM, AutoTokenizer +from vllm.config import VllmConfig from vllm.sampling_params import SamplingParams -from vllm.worker.model_runner import _get_graph_batch_size from ...utils import check_outputs_equal @@ -205,7 +205,7 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): + while len(example_prompts) == VllmConfig.get_graph_batch_size(len(example_prompts)): example_prompts.append(example_prompts[0]) try: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 1a4ddb13811c7..312a65769b634 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -90,76 +90,6 @@ def _bmm_chunk_fwd_kernel( out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2), - ], - key=['chunk_size', 'K'], -) -@triton.jit -def _bmm_chunk_bwd_kernel( - # Pointers to matrices - a_ptr, dout_ptr, db_ptr, res_ptr, - # Matrix dimensions - seqlen, chunk_size, K, ngroups, - stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, - stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n, - stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k, - stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k, - # Meta-parameters - dot_dtype: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_ch = tl.program_id(axis=2) - pid_c = pid_ch // ngroups - pid_h = pid_ch - pid_c * ngroups - num_pid_n = tl.cdiv(K, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - - a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head - dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_cs = tl.arange(0, BLOCK_SIZE_CS) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m) - a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)): - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype) - a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype) - acc += tl.dot(dout, a) - dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m - a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_RESIDUAL: - res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head - res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k) - res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32) - acc += res - db = acc.to(db_ptr.dtype.element_ty) - - db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head - db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k) - tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)) - - def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): """ Argument: @@ -206,57 +136,3 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=No HAS_SEQ_IDX=seq_idx is not None, ) return out - - -def _bmm_chunk_bwd(a, dout, residual=None, out=None): - """ - Argument: - a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) - residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - Return: - out: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - - If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be - zeroed out before calling this function. - """ - # Check constraints. - has_groups = a.dim() == 4 - if not has_groups: - batch, seqlen, k = a.shape - else: - batch, seqlen, ngroups, k = a.shape - nchunks, chunk_size = dout.shape[1], dout.shape[-1] - if a.stride(-1) != 1 and a.stride(-2) != 1: - a = a.contiguous() - if dout.stride(-1) != 1 and dout.stride(-2) != 1: - dout = dout.contiguous() - if residual is not None: - assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k) - if residual.stride(-1) != 1 and residual.stride(1) != 1: - residual = residual.contiguous() - # Allocates output. - if out is not None: - assert out.shape == a.shape - assert out.stride(-1) == 1 or out.stride(1) == 1 - else: - out = torch.empty_like(a) - dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else - (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32)) - grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch, - nchunks if not has_groups else nchunks * ngroups) - residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2), - residual.stride(-1)) - if residual is not None else (0, 0, 0, 0)) - with torch.cuda.device(a.device.index): - _bmm_chunk_bwd_kernel[grid]( - a, dout, out, residual, - seqlen, chunk_size, k, ngroups if has_groups else 1, - a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), - dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1), - out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1), - residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3], - dot_dtype, - HAS_RESIDUAL=residual is not None, - ) - return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index c1fabf0ac5904..79fa52e0b8c4f 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -3,19 +3,13 @@ """We want triton==2.1.0 or 2.2.0 for this """ -import math from packaging import version import torch -import torch.nn.functional as F import triton import triton.language as tl -from einops import rearrange, repeat - -from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd - TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -172,1061 +166,6 @@ def _chunk_scan_fwd_kernel( out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) - -@triton.autotune( - configs=[ - # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), - # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8), - triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8), - ], - key=['chunk_size', 'hdim', 'dstate'], -) -@triton.jit -def _chunk_scan_fwd_kernel_wip( - # Pointers to matrices - cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr, - # Matrix dimensions - chunk_size, hdim, dstate, - batch, seqlen, nheads_ngroups_ratio, - # Strides - stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, - stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, - stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate, - stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, - stride_D_head, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - HAS_Z: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - pid_n = tl.program_id(axis=0) - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head - B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head - prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - - offs_m = tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE) - - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate) - prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) - - prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) - # if pid_c == 0: - # if pid_b == 0: - # if pid_h == 0: - # tl.device_print("", prev_states) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # scale_m = tl.exp(dA_cs_m) - # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] - # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32) - # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) - # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # cb *= dt_m - # mask = offs_m[:, None] >= offs_m[None, :] - # cb = tl.where(mask, cb, 0.0) - # cb = cb.to(x_ptr.dtype.element_ty) - # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0) - # acc += tl.dot(cb, x) - # if HAS_D: - # if D_HAS_HDIM: - # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - # else: - # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - # acc += x.to(tl.float32) * D - # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - - for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M): - start_m = tl.multiple_of(start_m, BLOCK_SIZE_M) - dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1) - if not HAS_SEQ_IDX: - scale_m = tl.exp(dA_cs_m) - else: - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0) - acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] - # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32) - # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) - # cb *= dt_m - # mask = offs_m[:, None] >= offs_m[None, :] - # cb = tl.where(mask, cb, 0.0) - # cb = cb.to(x_ptr.dtype.element_ty) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0) - # acc += tl.dot(cb, x) - - if HAS_D: - if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - acc += x.to(tl.float32) * D - - # if HAS_Z: - # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) - # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) - - # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head - # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) - # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) - # acc *= z * tl.sigmoid(z) - - tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim)) - - # TODO: this is not correct, and quite a bit slower - if start_m + BLOCK_SIZE_M < chunk_size_limit: - # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32) - B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0) - dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32) - # TODO: seq_idx - scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m - # B *= scale - B = B.to(x_ptr.dtype.element_ty) - tmp = tl.dot(B, x) - prev_states += tmp.to(prev_states.dtype) - - C_ptrs += BLOCK_SIZE_M * stride_C_seqlen - B_ptrs += BLOCK_SIZE_M * stride_B_seqlen - cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k - x_ptrs += BLOCK_SIZE_M * stride_x_seqlen - dt_ptrs += BLOCK_SIZE_M * stride_dt_csize - out_ptrs += BLOCK_SIZE_M * stride_out_seqlen - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 32}), - triton.Config({'BLOCK_SIZE_M': 64}), - triton.Config({'BLOCK_SIZE_M': 128}), - triton.Config({'BLOCK_SIZE_M': 256}), - ], - key=["chunk_size", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_dz_kernel( - # Pointers to matrices - dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, hdim, - batch, seqlen, - # Strides - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, - stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_D_head, - stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim, - stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim, - stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim, - stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - HAS_DDACS: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head - out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head - dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head - if RECOMPUTE_OUTPUT: - outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head - if HAS_DDACS: - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - if HAS_D: - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_N) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim) - out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) - z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim) - dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim) - if RECOMPUTE_OUTPUT: - outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim) - if HAS_D: - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - if D_HAS_HDIM: - dD_ptrs = dD_ptr + offs_n * stride_dD_hdim - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - z_sigmoid = tl.sigmoid(z) - if RECOMPUTE_OUTPUT: - outz = out * z * z_sigmoid - tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid)) - tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - dout *= z * z_sigmoid - tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - if HAS_D: - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - if D_HAS_HDIM: - dD = tl.sum(dout * x, axis=0) - tl.store(dD_ptrs, dD, mask=offs_n < hdim) - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - dD = tl.sum(dout * x) - tl.store(dD_ptr, dD) - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - out -= x * D - if HAS_DDACS: - ddA_cs = tl.sum(dout * out, axis=1) - tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), - ], - key=['hdim', 'dstate', 'chunk_size'], -) -@triton.jit -def _chunk_scan_bwd_dstates_kernel( - # Pointers to matrices - dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr, - # Matrix dimensions - hdim, dstate, chunk_size, - batch, seqlen, nchunks, nheads_ngroups_ratio, - # Strides - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate, - stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen) - c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale_k = tl.exp(dA_cs_k) - else: - seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) - scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0) - dout = (dout * scale_k).to(dout_ptr.dtype.element_ty) - c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0) - acc += tl.dot(dout, c) - dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen - c_ptrs += BLOCK_SIZE_K * stride_c_seqlen - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen - out = acc.to(dprev_states_ptr.dtype.element_ty) - - dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate) - tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)) - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], - key=['chunk_size', 'dstate', 'hdim'], -) -@triton.jit -def _chunk_scan_bwd_dc_kernel( - # Pointers to matrices - dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, - dc_ptr, ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, dstate, hdim, - batch, seqlen, nheads, nheads_per_program, ngroups, - # Strides - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, - stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, - # Meta-parameters - HAS_DDA_CS: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_sg = tl.program_id(axis=2) - pid_s = pid_sg // ngroups - pid_g = pid_sg - pid_s * ngroups - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head - dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split - prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head - if HAS_DDA_CS: - C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize - if HAS_DDA_CS: - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - if HAS_DDA_CS: - c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) - for h in range(nheads_iter): - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) - prev_states = prev_states.to(dout_ptrs.dtype.element_ty) - dc = tl.dot(dout, prev_states) - dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_m) - else: - scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - dc *= scale[:, None] - if HAS_DDA_CS: - ddA_cs = tl.sum(dc * c, axis=1) - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - acc += dc - dout_ptrs += stride_dout_head - prev_states_ptrs += stride_prev_states_head - dA_cumsum_ptrs += stride_dA_cs_head - if HAS_DDA_CS: - ddA_cumsum_ptrs += stride_ddA_cs_head - # if HAS_SEQ_IDX: - # seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - # acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate) - tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - ], - key=['chunk_size', 'hdim'], -) -@triton.jit -def _chunk_scan_bwd_dx_kernel( - # Pointers to matrices - x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr, - dx_ptr, ddt_ptr, # dD_ptr, - # Matrix dimensions - chunk_size, hdim, - batch, seqlen, nheads_ngroups_ratio, - # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_D_head, - stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, - # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - # if HAS_D: - # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) - dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Idk why limiting K_MAX gives wrong results, is it a Triton bug? - # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) - K_MAX = chunk_size_limit - for k in range(0, K_MAX, BLOCK_SIZE_K): - # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) - dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) - cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) - # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, - # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. - # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. - # This will cause NaN in acc, and hence NaN in dx and ddt. - mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) - cb = tl.where(mask, cb, 0.0) - cb = cb.to(dout_ptr.dtype.element_ty) - acc += tl.dot(cb, dout) - cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k - dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - dx = acc * dt_m[:, None] - dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head - dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) - if HAS_D: - dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - dx += dout_res * D - tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - ddt = tl.sum(acc * x, axis=1) - ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) - - # if HAS_D: - # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) - # dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32) - # dD = tl.sum(x * dout, axis=0) - # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) - - -# Disabling HAS_DDA_CS for now since it's much slower -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), - ], - key=['chunk_size', 'hdim'], -) -# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) -# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32}) -@triton.jit -def _chunk_scan_bwd_dcb_kernel( - # Pointers to matrices - x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, - dcb_ptr, ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, hdim, - batch, seqlen, nheads, nheads_per_program, ngroups, - # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, - # Meta-parameters - HAS_DDA_CS: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_sg = tl.program_id(axis=2) - pid_s = pid_sg // ngroups - pid_g = pid_sg - pid_s * ngroups - num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head - if HAS_DDA_CS: - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_n * stride_dt_csize - if HAS_DDA_CS: - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n - - if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: - dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split - dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) - tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) - return - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - if HAS_DDA_CS: - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) - nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) - for h in range(nheads_iter): - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) - dcb = tl.dot(dout, x) - dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) - dcb *= dt_n - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) - dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - if HAS_DDA_CS: - tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") - ddA_cs = dcb * cb - mask = offs_m[:, None] >= offs_n[None, :] + 1 - ddA_cs = tl.where(mask, ddA_cs, 0.0) - ddA_cs = tl.cumsum(ddA_cs, axis=1) - ddA_cs = tl.where(mask, ddA_cs, 0.0) - ddA_cs = tl.sum(ddA_cs, axis=0) - tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) - tl.store(ddA_cumsum_ptr, 0.0) - acc += dcb - dout_ptrs += stride_dout_head - x_ptrs += stride_x_head - dt_ptrs += stride_dt_head - dA_cumsum_ptr += stride_dA_cs_head - if HAS_DDA_CS: - ddA_cumsum_ptr += stride_ddA_cs_head - ddA_cumsum_ptrs += stride_ddA_cs_head - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_SEQ_IDX: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) - mask = offs_m[:, None] >= offs_n[None, :] - acc = tl.where(mask, acc, 0.0) - dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split - dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) - tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) - - -# Not numerically stable and should not be used. Leaving here for reference. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 32}), - triton.Config({'BLOCK_SIZE_M': 64}), - triton.Config({'BLOCK_SIZE_M': 128}), - triton.Config({'BLOCK_SIZE_M': 256}), - ], - key=["chunk_size", "hdim"], -) -@triton.jit -def _chunk_scan_bwd_ddAcs_unstable_kernel( - # Pointers to matrices - dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr, - ddA_cumsum_ptr, dD_ptr, - # Matrix dimensions - chunk_size, hdim, - batch, seqlen, - # Strides - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_D_head, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, - stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - SUBTRACT_DDTDT: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - if HAS_D: - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_N) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) - if HAS_D: - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - if D_HAS_HDIM: - dD_ptrs = dD_ptr + offs_n * stride_dD_hdim - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - if HAS_D: - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - if D_HAS_HDIM: - dD = tl.sum(dout * x, axis=0) - tl.store(dD_ptrs, dD, mask=offs_n < hdim) - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - dD = tl.sum(dout * x) - tl.store(dD_ptr, dD) - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - out -= x * D - ddA_cs = tl.sum(dout * out, axis=1) - if SUBTRACT_DDTDT: - dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - ddA_cs -= dt * ddt - tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) - - -@triton.autotune( - configs=[ - # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), - ], - key=['chunk_size', 'hdim'], -) -@triton.jit -def _chunk_scan_bwd_ddAcs_stable_kernel_old( - # Pointers to matrices - x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, - ddAcs_ptr, - # Matrix dimensions - chunk_size, hdim, - batch, seqlen, nheads_ngroups_ratio, - # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, - stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_n * stride_dt_csize - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) - # Doing a matmul loop with cumsum later on will cause Triton to crash - # Instead we do just one big matmul - # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # for k in range(0, hdim, BLOCK_SIZE_K): - # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) - # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) - # acc += tl.dot(dout, x) - # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim - # x_ptrs += BLOCK_SIZE_K * stride_x_hdim - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) - acc = tl.dot(dout, x) - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) - acc *= cb - dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) - acc *= dt_n - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) - acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - mask = offs_m[:, None] >= offs_n[None, :] + 1 - acc = tl.where(mask, acc, 0.0) - acc = tl.cumsum(acc, axis=1) - acc = tl.where(mask, acc, 0.0) - ddA_cs = tl.sum(acc, axis=0) - ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n - tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) - tl.store(ddAcs_ptr, 0.0) - - # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64) - # offs_k = tl.arange(0, BLOCK_SIZE_K) - # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) - # dt_ptrs = dt_ptr + offs_n * stride_dt_csize - # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) - - # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) - # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m - # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n - # for n in range(0, chunk_size_limit_n, 64): - # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0) - # acc = tl.dot(dout, x) - # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32) - # acc *= cb - # dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) - # acc *= dt_n - # dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) - # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - # mask = offs_m[:, None] >= offs_n[None, :] + 1 + n - # acc = tl.where(mask, acc, 0.0) - # acc = tl.cumsum(acc, axis=1) - # acc = tl.where(mask, acc, 0.0) - # ddA_cs = tl.sum(acc, axis=0) - # tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n) - # # tl.store(ddAcs_ptr, 0.0) - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), - ], - key=['chunk_size', 'hdim'], -) -@triton.jit -def _chunk_scan_bwd_ddAcs_stable_kernel( - # Pointers to matrices - x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, hdim, - batch, seqlen, nheads_ngroups_ratio, - # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) - dt_ptrs = dt_ptr + offs_n * stride_dt_csize - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) - ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n - tl.store(ddA_cumsum_ptr, 0.0) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower - lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M - # lo, hi = 0, chunk_size - for start_n in range(lo, hi, BLOCK_SIZE_N): - start_n = tl.multiple_of(start_n, BLOCK_SIZE_N) - # Doing a matmul loop with cumsum later on will cause Triton to crash - # Instead we do just one big matmul - # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # for k in range(0, hdim, BLOCK_SIZE_K): - # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) - # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) - # acc += tl.dot(dout, x) - # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim - # x_ptrs += BLOCK_SIZE_K * stride_x_hdim - # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) - acc = tl.dot(dout, x) - dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) - acc *= dt_n - # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j] - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) - acc *= cb - dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) - acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 - acc = tl.where(mask, acc, 0.0) - rowsum_new = rowsum + tl.sum(acc, axis=1) - acc = rowsum[:, None] + tl.cumsum(acc, axis=1) - rowsum = rowsum_new - acc = tl.where(mask, acc, 0.0) - ddA_cs = tl.sum(acc, axis=0) - tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1) - x_ptrs += BLOCK_SIZE_N * stride_x_seqlen - dt_ptrs += BLOCK_SIZE_N * stride_dt_csize - cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n - ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n - - # Need to zero out the rest, since we'll be summing the rows together - for start_n in range(hi, chunk_size, BLOCK_SIZE_N): - tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1) - ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], - key=['chunk_size', 'dstate', 'hdim'], -) -@triton.jit -def _chunk_scan_bwd_ddAcs_prev_kernel( - # Pointers to matrices - dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, dstate, hdim, - batch, seqlen, nchunks, nheads_ngroups_ratio, - # Strides - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, - stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, - # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head - C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) - prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) - prev_states = prev_states.to(dout_ptrs.dtype.element_ty) - acc = tl.dot(dout, prev_states) - c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) - ddA_cs = tl.sum(acc * c, axis=1) - dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_m) - if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - ddA_cs *= scale - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - - def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape @@ -1276,554 +215,4 @@ def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=Non HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, ) - return out, out_x - - -def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = C.shape - assert nheads % ngroups == 0 - assert C.shape == (batch, seqlen, ngroups, dstate) - assert B.shape == C.shape - assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - if z is not None: - assert z.shape == x.shape - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - assert states.shape == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - # Allocates output. - out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) - if z is not None: - out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) - assert out_x.stride() == out.stride() - else: - out_x = None - grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) - z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) - if z is not None else (0, 0, 0, 0)) - _chunk_scan_fwd_kernel_wip[grid]( - cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D, - chunk_size, headdim, dstate, - batch, seqlen, nheads // ngroups, - cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - z_strides[0], z_strides[1], z_strides[2], z_strides[3], - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - C.stride(0), C.stride(1), C.stride(2), C.stride(3), - B.stride(0), B.stride(1), B.stride(2), B.stride(3), - states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), - D.stride(0) if D is not None else 0, - D is not None, - D.dim() == 2 if D is not None else True, - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - BLOCK_SIZE_M=128, - HAS_Z=z is not None, - HAS_SEQ_IDX=seq_idx is not None, - ) - return out, out_x - - -def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False): - batch, seqlen, nheads, headdim = x.shape - assert z.shape == x.shape - assert out.shape == x.shape - assert dout.shape == out.shape - nchunks = math.ceil(seqlen / chunk_size) - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) - assert D.stride(-1) == 1 - if has_ddAcs: - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) - if D is not None: - BLOCK_SIZE_min = 32 - dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, - headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) - else: - dD = None - if dz is not None: - assert dz.shape == z.shape - else: - dz = torch.empty_like(z) - if recompute_output: - outz = torch.empty_like(x) - dout_x = torch.empty_like(dout) - dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) - if D is not None else (0, 0, 0, 0, 0)) - grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) - with torch.cuda.device(x.device.index): - _chunk_scan_bwd_dz_kernel[grid_dz]( - dout, out, z, x, D, outz if recompute_output else None, - dz, dout_x, dD, ddA_cumsum if has_ddAcs else None, - chunk_size, headdim, - batch, seqlen, - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - z.stride(0), z.stride(1), z.stride(2), z.stride(3), - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - D.stride(0) if D is not None else 0, - *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)), - dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3), - dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3), - dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], - *((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) - if has_ddAcs else (0, 0, 0, 0)), - D is not None, - D.dim() == 2 if D is not None else True, - has_ddAcs, - BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), - RECOMPUTE_OUTPUT=recompute_output, - ) - if D is not None: - BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) - if D.dim() == 1: - dD = rearrange(dD, "h 1 -> h") - return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD) - return return_vals if not recompute_output else (*return_vals, outz) - - -def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None): - batch, seqlen, nheads, headdim = dout.shape - _, _, nchunks, chunk_size = dA_cumsum.shape - _, _, ngroups, dstate = C.shape - assert nheads % ngroups == 0 - assert C.shape == (batch, seqlen, ngroups, dstate) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - dtype = C.dtype if dtype is None else dtype - dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype) - grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) - with torch.cuda.device(C.device.index): - _chunk_scan_bwd_dstates_kernel[grid_dstates]( - dout, C, dprev_states, dA_cumsum, seq_idx, - headdim, dstate, chunk_size, - batch, seqlen, nchunks, nheads // ngroups, - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - C.stride(0), C.stride(1), C.stride(2), C.stride(3), - dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - HAS_SEQ_IDX=seq_idx is not None, - ) - return dprev_states - - -def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1): - batch, nchunks, nheads, headdim, dstate = prev_states.shape - _, seqlen, _, _ = dout.shape - _, _, _, chunk_size = dA_cumsum.shape - assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - assert dout.shape == (batch, seqlen, nheads, headdim) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if C is not None: - assert C.shape == (batch, seqlen, ngroups, dstate) - C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3)) - ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) - ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3)) - else: - C_strides = (0, 0, 0, 0) - ddA_cumsum_prev = None - ddA_cumsum_prev_strides = (0, 0, 0, 0) - nheads_ngroups_ratio = nheads // ngroups - sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count - nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) - nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) - dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32) - grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), - batch * nchunks, nsplits * ngroups) - with torch.cuda.device(dout.device.index): - _chunk_scan_bwd_dc_kernel[grid_dc]( - dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev, - chunk_size, dstate, headdim, - batch, seqlen, nheads, nheads_per_program, ngroups, - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), - *C_strides, - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4), - *ddA_cumsum_prev_strides, - HAS_DDA_CS=ddA_cumsum_prev is not None, - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - dC = dC.sum(2) - return dC if C is None else (dC, ddA_cumsum_prev) - - -def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dout.shape == x.shape - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if CB is not None: - assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4)) - BLOCK_SIZE_M_min = 16 - ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), - chunk_size, device=x.device, dtype=torch.float32) - ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4)) - else: - CB_strides = (0, 0, 0, 0, 0) - ddA_cumsum = None - ddA_cumsum_strides = (0, 0, 0, 0, 0) - nheads_ngroups_ratio = nheads // ngroups - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count - nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) - nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) - dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32) - grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), - batch * nchunks, nsplits * ngroups) - with torch.cuda.device(x.device.index): - _chunk_scan_bwd_dcb_kernel[grid_dcb]( - x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum, - chunk_size, headdim, - batch, seqlen, nheads, nheads_per_program, ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - *CB_strides, - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5), - *ddA_cumsum_strides, - HAS_DDA_CS=ddA_cumsum is not None, - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - dcb = dcb.sum(2) - if ddA_cumsum is not None: - BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual - ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) - return dcb if CB is None else (dcb, ddA_cumsum) - - -def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - ngroups = cb.shape[2] - assert nheads % ngroups == 0 - assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dout.shape == x.shape - # if D is not None: - # BLOCK_SIZE_M_min = 32 - # dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32) - # else: - # dD = None - dx = torch.empty_like(x) - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) - grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) - with torch.cuda.device(x.device.index): - _chunk_scan_bwd_dx_kernel[grid_dx]( - x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD, - chunk_size, headdim, - batch, seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2), - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - D.stride(0) if D is not None else 0, - dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), - # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0, - D is not None, - D.dim() == 2 if D is not None else True, - ) - # if D is not None: - # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] - # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) - return dx, ddt.to(dtype=dt.dtype) - - -def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): - """Not numerically stable and should not be used. Leaving here for reference. - """ - - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert ddt.shape == dt.shape - assert out.shape == x.shape - assert dout.shape == x.shape - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) - ddA_cumsum = torch.empty_like(dt) - grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) - if D is not None: # Triton gives wrong results if we write to the same location - BLOCK_SIZE_min = 32 - dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, - headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) - else: - dD = None - dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) - if D is not None else (0, 0, 0, 0, 0)) - with torch.cuda.device(x.device.index): - _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs]( - dout, out, dt, ddt, x, D, ddA_cumsum, dD, - chunk_size, headdim, - batch, seqlen, - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - D.stride(0) if D is not None else 0, - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), - dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], - D is not None, - D.dim() == 2 if D is not None else True, - subtract_ddtdt, - BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), - ) - if D is not None: - BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) - if D.dim() == 1: - dD = rearrange(dD, "h 1 -> h") - return ddA_cumsum, dD - - -def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dout.shape == x.shape - assert dA_cumsum.shape == dt.shape - ngroups = cb.shape[2] - assert nheads % ngroups == 0 - assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - BLOCK_SIZE_M_min = 16 - ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), - chunk_size, device=x.device, dtype=torch.float32) - grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) - with torch.cuda.device(x.device.index): - _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs]( - x, dout, dt, dA_cumsum, cb, ddA_cumsum, - chunk_size, headdim, - batch, seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16), - ) - BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual - ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) - return ddA_cumsum - - -def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dout.shape == x.shape - assert dA_cumsum.shape == dt.shape - ngroups = cb.shape[2] - assert nheads % ngroups == 0 - assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - BLOCK_SIZE_M_min = 32 - ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), - chunk_size, device=x.device, dtype=torch.float32) - grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) - with torch.cuda.device(x.device.index): - _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs]( - x, dout, dt, dA_cumsum, cb, ddA_cumsum, - chunk_size, headdim, - batch, seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual - ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) - return ddA_cumsum - - -def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None): - batch, nchunks, nheads, headdim, dstate = prev_states.shape - _, seqlen, _, _ = dout.shape - _, _, _, chunk_size = dA_cumsum.shape - assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - assert dout.shape == (batch, seqlen, nheads, headdim) - ngroups = C.shape[2] - assert nheads % ngroups == 0 - assert C.shape == (batch, seqlen, ngroups, dstate) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) - grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) - with torch.cuda.device(dout.device.index): - _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs]( - dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev, - chunk_size, dstate, headdim, - batch, seqlen, nchunks, nheads // ngroups, - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), - C.stride(0), C.stride(1), C.stride(2), C.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3), - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - return ddA_cumsum_prev - - -class ChunkScanFn(torch.autograd.Function): - - @staticmethod - def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): - # Check constraints. - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape - assert B.shape == (batch, seqlen, ngroups, dstate) - _, _, nchunks, chunk_size = dt.shape - assert seqlen == nchunks * chunk_size - assert C.shape == B.shape - if z is not None: - assert z.shape == x.shape - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous - x = x.contiguous() - if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous - z = z.contiguous() - if D is not None and D.stride(-1) != 1: - D = D.contiguous() - CB = _bmm_chunk_fwd(C, B, chunk_size) - out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z) - ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z) - return out - - @staticmethod - def backward(ctx, dout): - if dout.stride(-1) != 1: - dout = dout.contiguous() - out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert dout.shape == (batch, seqlen, nheads, headdim) - if z is not None: - dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D) - else: - dz = None - dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype) - dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups) - dC = dC.to(C.dtype) - dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups) - dCB = dCB.to(CB.dtype) - dB = _bmm_chunk_bwd(C, dCB) - dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC) - dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D) - # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. - # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt - if z is not None: - ddA_cumsum -= ddt * dt - else: # If z is not None, we already calculated ddA_cumsum and dD when computing dz - ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D) - ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype) - return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz - - -def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): - """ - prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1. - Argument: - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) - dA_cumsum: (batch, nheads, nchunks, chunk_size) - prev_states: (batch, nchunks, nheads, headdim, dstate) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - Return: - out: (batch, seqlen, nheads, headdim) - """ - return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z) - - -def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): - """ - Argument: - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) - dA_cumsum: (batch, nheads, nchunks, chunk_size) - prev_states: (batch, nchunks, nheads, headdim, dstate) - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - Return: - out: (batch, seqlen, nheads, headdim) - """ - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape - assert B.shape == (batch, seqlen, ngroups, dstate) - _, _, nchunks, chunk_size = dt.shape - assert seqlen == nchunks * chunk_size - assert C.shape == B.shape - B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) - C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) - CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), - rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) - # (batch, nheads, nchunks, chunksize, chunksize) - dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] - decay = torch.exp(dt_segment_sum) - scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s") - causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) - scores_decay = scores_decay.masked_fill(~causal_mask, 0) - out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), - rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) - state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) - out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), - prev_states.to(C.dtype)) * state_decay_out - out = out + out_prev - out = rearrange(out, "b c l h p -> b (c l) h p") - if D is not None: - if D.dim() == 1: - D = rearrange(D, "h -> h 1") - out = out + x * D - return out if z is None else out * F.silu(z) + return out, out_x \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 5116735d2840b..3184bbbf03d41 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -83,85 +83,6 @@ def _chunk_cumsum_fwd_kernel( tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), - ], - key=['chunk_size', 'nheads'], -) -@triton.jit -def _chunk_cumsum_bwd_kernel( - # Pointers to matrices - ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr, - ddt_ptr, dA_ptr, ddt_bias_ptr, - # Matrix dimensions - batch, seqlen, nheads, chunk_size, - dt_min, dt_max, - # Strides - stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize, - stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize, - stride_dt_batch, stride_dt_seqlen, stride_dt_head, - stride_A_head, - stride_dt_bias_head, - stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head, - stride_dA_head, - stride_ddt_bias_head, - # Meta-parameters - DT_SOFTPLUS: tl.constexpr, - HAS_DT_BIAS: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, -): - pid_b = tl.program_id(axis=0) - pid_c = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk - ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk - dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen - ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen - - offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) - offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) - ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize) - ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize) - dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) - ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen) - A_ptrs = A_ptr + offs_h * stride_A_head - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) - ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) - A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) - ddt = ddA * A[:, None] + ddt_out - dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) - if HAS_DT_BIAS: - dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) - dt += dt_bias[:, None] - if DT_SOFTPLUS: - dt_presoftplus = dt - dt = tl.where(dt <= 20.0, softplus(dt), ddt) - clamp_mask = (dt < dt_min) | (dt > dt_max) - # As of Triton 2.2.0, tl.clamp is not available yet - # dt = tl.clamp(dt, dt_min, dt_max) - dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) - dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) - ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0) - ddt = tl.where(clamp_mask, 0.0, ddt) - if DT_SOFTPLUS: - ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt) - tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) - dA = tl.sum(ddA * dt, axis=1) - tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) - if HAS_DT_BIAS: - ddt_bias = tl.sum(ddt, axis=1) - tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) - - @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), @@ -253,327 +174,6 @@ def _chunk_state_fwd_kernel( c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), - ], - key=['chunk_size', 'hdim', 'dstate'], -) -@triton.jit -def _chunk_state_bwd_dx_kernel( - # Pointers to matrices - x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, - dx_ptr, ddt_ptr, ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, hdim, dstate, - batch, seqlen, nheads_ngroups_ratio, - # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, - stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head - dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) - dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) - if BLOCK_SIZE_DSTATE <= 128: - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc = tl.dot(b, dstates) - else: - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, dstate, BLOCK_SIZE_K): - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc += tl.dot(b, dstates) - b_ptrs += BLOCK_SIZE_K * stride_b_dstate - dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize - dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] - - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - ddt = tl.sum(acc * x, axis=1) - ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) - ddA_cs = -(ddt * dt_m) - ddA_cs_last = -tl.sum(ddA_cs) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) - - dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty) - dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head - dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) - tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], - key=['chunk_size', 'dstate', 'hdim'], -) -@triton.jit -def _chunk_state_bwd_db_kernel( - # Pointers to matrices - x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, - db_ptr, ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, dstate, hdim, - batch, seqlen, nheads, nheads_per_program, ngroups, - # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, - stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, - # Meta-parameters - HAS_DDA_CS: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_sg = tl.program_id(axis=2) - pid_s = pid_sg // ngroups - pid_g = pid_sg - pid_s * ngroups - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head - db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split - dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head - if HAS_DDA_CS: - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim) - dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize - if HAS_DDA_CS: - b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate) - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - if HAS_DDA_CS: - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) - for h in range(nheads_iter): - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) - dstates = dstates.to(x_ptrs.dtype.element_ty) - db = tl.dot(x, dstates) - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) - else: - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - db *= (scale * dt_m)[:, None] - if HAS_DDA_CS: - # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum - ddA_cs = tl.sum(db * b, axis=1) - tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) - acc += db - x_ptrs += stride_x_head - dstates_ptrs += stride_states_head - dt_ptrs += stride_dt_head - dA_cumsum_ptr += stride_dA_cs_head - dA_cumsum_ptrs += stride_dA_cs_head - if HAS_DDA_CS: - ddA_cumsum_ptrs += stride_ddA_cs_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - # if HAS_SEQ_IDX: - # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0) - db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate) - tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) - - -@triton.autotune( - configs=[ - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), - ], - key=['chunk_size', 'hdim', 'dstate'], -) -@triton.jit -def _chunk_state_bwd_ddAcs_stable_kernel( - # Pointers to matrices - x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, - ddA_cumsum_ptr, - # Matrix dimensions - chunk_size, hdim, dstate, - batch, seqlen, nheads_ngroups_ratio, - # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, - stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, - # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head - dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) - dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) - if BLOCK_SIZE_DSTATE <= 128: - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc = tl.dot(b, dstates) - else: - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, dstate, BLOCK_SIZE_K): - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc += tl.dot(b, dstates) - b_ptrs += BLOCK_SIZE_K * stride_b_dstate - dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) - else: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - acc *= scale[:, None] - - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - ddt = tl.sum(acc * x, axis=1) - # ddA_cs = -(ddt * dt_m) - # Triton 2.2.0 errors if we have the cumsum here, so we just write it out - # then call torch.cumsum outside this kernel. - # ddA_cs = tl.cumsum(ddt * dt_m) - ddA_cs = ddt * dt_m - ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize - # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) - tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) - - @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), @@ -690,44 +290,6 @@ def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_lim ) return dA_cumsum, dt_out - -def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None): - batch, seqlen, nheads = dt.shape - _, _, nchunks, chunk_size = ddA.shape - assert ddA.shape == (batch, nheads, nchunks, chunk_size) - assert ddt_out.shape == (batch, nheads, nchunks, chunk_size) - assert A.shape == (nheads,) - if dt_bias is not None: - assert dt_bias.shape == (nheads,) - ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) - else: - ddt_bias = None - if ddt is not None: - assert ddt.shape == dt.shape - else: - ddt = torch.empty_like(dt) - dA = torch.empty_like(A, dtype=torch.float32) - grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) - with torch.cuda.device(dt.device.index): - _chunk_cumsum_bwd_kernel[grid_chunk_cs]( - ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias, - batch, seqlen, nheads, chunk_size, - dt_limit[0], dt_limit[1], - ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3), - ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3), - dt.stride(0), dt.stride(1), dt.stride(2), - A.stride(0), - dt_bias.stride(0) if dt_bias is not None else 0, - ddt.stride(0), ddt.stride(1), ddt.stride(2), - dA.stride(0), - ddt_bias.stride(0) if ddt_bias is not None else 0, - dt_softplus, - HAS_DT_BIAS=dt_bias is not None, - BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), - ) - return ddt, dA, ddt_bias - - def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape @@ -760,130 +322,6 @@ def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_f ) return states - -def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) - if dx is not None: - assert dx.shape == x.shape - else: - dx = torch.empty_like(x) - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32) - grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) - with torch.cuda.device(x.device.index): - _chunk_state_bwd_dx_kernel[grid_dx]( - x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum, - chunk_size, headdim, dstate, - batch, seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - B.stride(0), B.stride(1), B.stride(2), B.stride(-1), - dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - ) - return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype) - - -def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - dstate = dstates.shape[-1] - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if B is not None: - assert B.shape == (batch, seqlen, ngroups, dstate) - B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3)) - # Use torch.empty since the Triton kernel will call init_to_zero - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) - ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) - else: - B_strides = (0, 0, 0, 0) - ddA_cumsum = None - ddA_cumsum_strides = (0, 0, 0, 0) - nheads_ngroups_ratio = nheads // ngroups - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count - nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) - nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) - dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32) - grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), - batch * nchunks, nsplits * ngroups) - with torch.cuda.device(x.device.index): - _chunk_state_bwd_db_kernel[grid_db]( - x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum, - chunk_size, dstate, headdim, - batch, seqlen, nheads, nheads_per_program, ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), - *B_strides, - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4), - *ddA_cumsum_strides, - HAS_DDA_CS=ddA_cumsum is not None, - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), - ) - dB = dB.sum(2) - if ddA_cumsum is not None: - # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute - # to the state of the chunk. - # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) - # But it's easier to just do the cumsum for all elements, the result will be the same. - torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum) - return dB if B is None else (dB, ddA_cumsum) - - -def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - # Use torch.empty since the Triton kernel will call init_to_zero - ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) - grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) - with torch.cuda.device(x.device.index): - _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs]( - x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum, - chunk_size, headdim, dstate, - batch, seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - B.stride(0), B.stride(1), B.stride(2), B.stride(-1), - dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - ) - torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) - return ddA_cumsum - - def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): total_seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape @@ -911,81 +349,3 @@ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): states.stride(0), states.stride(1), states.stride(2), states.stride(3), ) return states - - -class ChunkStateFn(torch.autograd.Function): - - @staticmethod - def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - assert seqlen <= nchunks * chunk_size - _, _, ngroups, dstate = B.shape - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - if B.stride(-1) != 1: - B = B.contiguous() - if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous - x = x.contiguous() - states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32) - ctx.save_for_backward(B, x, dt, dA_cumsum) - return states - - @staticmethod - def backward(ctx, dstates): - B, x, dt, dA_cumsum = ctx.saved_tensors - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) - if dstates.stride(-1) != 1: - dstates = dstates.contiguous() - dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates) - dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups) - dB = dB.to(B.dtype) - return dB, dx, ddt, ddA_cumsum, None - - -def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True): - """ - Argument: - B: (batch, seqlen, ngroups, headdim) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) - dA_cumsum: (batch, nheads, nchunks, chunk_size) - Return: - states: (batch, nchunks, nheads, headdim, dstate) - """ - return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32) - - -def chunk_state_ref(B, x, dt, dA_cumsum): - """ - Argument: - B: (batch, seqlen, ngroups, headdim) - x: (batch, seqlen, nheads, headdim) - dt: (batch, nheads, nchunks, chunk_size) - dA_cumsum: (batch, nheads, nchunks, chunk_size) - Return: - states: (batch, nchunks, nheads, headdim, dstate) - """ - # Check constraints. - batch, seqlen, nheads, headdim = x.shape - dstate = B.shape[-1] - _, _, nchunks, chunk_size = dt.shape - assert seqlen <= nchunks * chunk_size - assert x.shape == (batch, seqlen, nheads, headdim) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - ngroups = B.shape[2] - assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - if seqlen < nchunks * chunk_size: - x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) - B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) - x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) - B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) - decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) - return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index a6fb60c199667..728024a6b31fa 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -3,260 +3,26 @@ """We want triton==2.1.0 or 2.2.0 for this """ -import math from packaging import version import torch import triton -import triton.language as tl from einops import rearrange -from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd -from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd -from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_state import _chunk_cumsum_fwd +from .ssd_chunk_state import _chunk_state_fwd from .ssd_chunk_state import chunk_state_varlen -from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd -from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates -from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb -from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable +from .ssd_state_passing import _state_passing_fwd +from .ssd_chunk_scan import _chunk_scan_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), - ], - key=['chunk_size', 'hdim', 'dstate'], -) -@triton.jit -def _chunk_scan_chunk_state_bwd_dx_kernel( - # Pointers to matrices - x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr, - b_ptr, dstates_ptr, - dx_ptr, ddt_ptr, dD_ptr, - # Matrix dimensions - chunk_size, hdim, dstate, - batch, seqlen, nheads_ngroups_ratio, - # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, - stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_D_head, - stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, - stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, - stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, - stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, - stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, - # Meta-parameters - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, - IS_TRITON_22: tl.constexpr, -): - pid_bc = tl.program_id(axis=1) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head - dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_m) - else: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - # However, we're getting error with the Triton compiler 2.1.0 for that code path: - # Unexpected mma -> mma layout conversion - # Triton 2.2.0 fixes this - offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate) - dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate) - if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc = tl.dot(b, dstates) * scale[:, None] - else: - for k in range(0, dstate, BLOCK_SIZE_K): - b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0) - dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) - dstates = dstates.to(b_ptr.dtype.element_ty) - acc += tl.dot(b, dstates) - b_ptrs += BLOCK_SIZE_K * stride_b_dstate - dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate - acc *= scale[:, None] - - # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - # dt_ptrs = dt_ptr + offs_m * stride_dt_csize - # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - # ddt = tl.sum(acc * x, axis=1) * dt_m - # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) - - offs_k = tl.arange(0, BLOCK_SIZE_K) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) - dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - K_MAX = chunk_size_limit - K_MIN = pid_m * BLOCK_SIZE_M - cb_ptrs += K_MIN * stride_cb_csize_k - dout_ptrs += K_MIN * stride_dout_seqlen - dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize - for k in range(K_MIN, K_MAX, BLOCK_SIZE_K): - k = tl.multiple_of(k, BLOCK_SIZE_K) - # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) - dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) - cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) - # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, - # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. - # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. - # This will cause NaN in acc, and hence NaN in dx and ddt. - mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) - cb = tl.where(mask, cb, 0.0) - cb = cb.to(dout_ptr.dtype.element_ty) - acc += tl.dot(cb, dout) - cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k - dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dt_ptrs = dt_ptr + offs_m * stride_dt_csize - dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) - dx = acc * dt_m[:, None] - dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head - dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) - if HAS_D: - dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) - dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) - else: - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - dx += dout_res * D - tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) - - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) - x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) - if HAS_D: - dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize - if D_HAS_HDIM: - dD_ptrs = dD_ptr + offs_n * stride_dD_hdim - dD = tl.sum(dout_res * x, axis=0) - tl.store(dD_ptrs, dD, mask=offs_n < hdim) - else: - dD = tl.sum(dout_res * x) - tl.store(dD_ptr, dD) - ddt = tl.sum(acc * x, axis=1) - ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize - tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) - - -def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert dout.shape == x.shape - assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) - assert D.stride(-1) == 1 - BLOCK_SIZE_min = 32 - dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, - headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) - else: - dD = None - dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) - if D is not None else (0, 0, 0, 0, 0)) - if dx is None: - dx = torch.empty_like(x) - else: - assert dx.shape == x.shape - ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) - grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) - with torch.cuda.device(x.device.index): - _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx]( - x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD, - chunk_size, headdim, dstate, - batch, seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2), - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - D.stride(0) if D is not None else 0, - B.stride(0), B.stride(1), B.stride(2), B.stride(3), - dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), - dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), - ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), - dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], - D is not None, - D.dim() == 2 if D is not None else True, - HAS_SEQ_IDX=seq_idx is not None, - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - IS_TRITON_22=TRITON_22 - ) - if D is not None: - BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] - n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) - if D.dim() == 1: - dD = rearrange(dD, "h 1 -> h") - return dx, ddt.to(dtype=dt.dtype), dD - def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape @@ -309,156 +75,6 @@ def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, d cu_seqlens, states.squeeze(0)) return out, out_x, dt, dA_cumsum, states, final_states, varlen_states - -def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, - dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, - dt_limit=(0.0, float("inf")), - dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False): - if dout.stride(-1) != 1: - dout = dout.contiguous() - batch, seqlen, nheads, headdim = x.shape - nchunks = math.ceil(seqlen / chunk_size) - _, _, ngroups, dstate = B.shape - assert dout.shape == (batch, seqlen, nheads, headdim) - assert dt.shape == (batch, seqlen, nheads) - assert A.shape == (nheads,) - assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert C.shape == B.shape - assert out.shape == x.shape - if initial_states is not None: - assert initial_states.shape == (batch, nheads, headdim, dstate) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if dx is not None: - assert dx.shape == x.shape - if dB is not None: - assert dB.shape == B.shape - dB_given = dB - else: - dB_given = torch.empty_like(B) - if dC is not None: - assert dC.shape == C.shape - dC_given = dC - else: - dC_given = torch.empty_like(C) - if dz is not None: - assert z is not None - assert dz.shape == z.shape - if ddt is not None: - assert ddt.shape == dt.shape - ddt_given = ddt - else: - ddt_given = torch.empty_like(dt) - # TD: For some reason Triton (2.1.0 and 2.2.0) errors with - # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why. - dt_in = dt.clone() - dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, - dt_limit=dt_limit) - CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) - states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) - states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], - initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, - seq_idx=seq_idx, chunk_size=chunk_size) - states = rearrange(states, "... (p n) -> ... p n", n=dstate) - if z is not None: - dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output) - outz = rest[0] if recompute_output else out - else: - dz = None - outz = out - dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype) - # dstates has length nchunks, containing the gradient to initial states at index 0 and - # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1) - # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states - # will be used in matmul in the next kernels. - dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd( - rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], - rearrange(dstates, "... p n -> ... (p n)"), - dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None, - seq_idx=seq_idx, - has_initial_states=initial_states is not None, - dstates_dtype=x.dtype, - states_dtype=x.dtype, - chunk_size=chunk_size, - ) - # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and - # gradient to the final states at index (nchunks - 1) - # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1) - # The final states is not stored. - states = rearrange(states, "... (p n) -> ... p n", n=dstate) - dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate) - dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None - dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx) - # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups) - dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups) - # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) - dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups) - # Computing ddA with the dcb kernel is much slower, so we're not using it for now - dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) - # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups) - dCB = dCB.to(CB.dtype) - _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given) - _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given) - # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate - # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16 - if z is None: - dD = dD_from_x - # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. - # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt - # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might - # be a lot of underflow. - - # This is already done as part of bwd_dC kernel - # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx) - ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum - ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1]) - # This is already done as part of bwd_dB kernel - # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx) - # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j] - ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB) - ddA += ddA_next + ddA_prev - - ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given) - - # These 2 lines are just to test ddt and dA being computed by old code - # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z) - # ddt_given.copy_(ddt) - - return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states) - return return_vals if not recompute_output else (*return_vals, outz) - -class MambaChunkScanCombinedFn(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): - ctx.dt_dtype = dt.dtype - if not return_varlen_states: - cu_seqlens = None - else: - assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" - out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) - ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx) - ctx.dt_softplus = dt_softplus - ctx.chunk_size = chunk_size - ctx.dt_limit = dt_limit - ctx.return_final_states = return_final_states - ctx.return_varlen_states = return_varlen_states - if not return_varlen_states: - return out if not return_final_states else (out, final_states) - else: - varlen_states = rest[0] - return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) - - @staticmethod - def backward(ctx, dout, *args): - out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors - assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward" - dfinal_states = args[0] if ctx.return_final_states else None - dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit) - return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None - def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): """ Argument: @@ -478,4 +94,14 @@ def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bia Return: out: (batch, seqlen, nheads, headdim) """ - return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states) \ No newline at end of file + + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 63863b8236e1c..59ed1d17cfda2 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -3,15 +3,11 @@ """We want triton==2.1.0 or 2.2.0 for this """ -import math import torch -import torch.nn.functional as F import triton import triton.language as tl -from einops import rearrange, repeat - @triton.autotune( configs=[ @@ -85,112 +81,6 @@ def _state_passing_fwd_kernel( out_ptrs += stride_out_chunk -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE': 64}), - triton.Config({'BLOCK_SIZE': 128}), - triton.Config({'BLOCK_SIZE': 256}), - triton.Config({'BLOCK_SIZE': 512}), - triton.Config({'BLOCK_SIZE': 1024}), - triton.Config({'BLOCK_SIZE': 2048}), - ], - key=['dim'], -) -@triton.jit -def _state_passing_bwd_kernel( - # Pointers to matrices - dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr, - dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr, - # Matrix dimensions - dim, nchunks, seqlen, chunk_size, - # Strides - stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim, - stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, - stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim, - stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, - stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim, - # Meta-parameters - CONVERT_STATES: tl.constexpr, - HAS_DFINAL_STATES: tl.constexpr, - HAS_DINITSTATES: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk - ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk - dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk - if CONVERT_STATES: - states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk - if HAS_DFINAL_STATES: - dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head - if HAS_DINITSTATES: - dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch - - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - dout_ptrs = dout_ptr + offs_m * stride_dout_dim - if CONVERT_STATES: - states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim - - if HAS_DFINAL_STATES: - dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32) - else: - dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - tl.store(dstates_ptrs, dstates, mask=offs_m < dim) - if HAS_SEQ_IDX: - seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen) - dstates_ptrs -= stride_dstates_chunk - for c in range(nchunks - 1): - dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) - if HAS_SEQ_IDX: - seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen)) - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) - seq_idx = seq_idx_new - out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if CONVERT_STATES: - tl.store(states_converted_ptrs, out, mask=offs_m < dim) - ddA = tl.sum(out * dstates) * scale - tl.store(ddA_cs_ptr, ddA) - dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - dstates = scale * dstates + dout - tl.store(dstates_ptrs, dstates, mask=offs_m < dim) - dout_ptrs -= stride_dout_chunk - dstates_ptrs -= stride_dstates_chunk - dA_cs_ptr -= stride_dA_cs_chunk - ddA_cs_ptr -= stride_ddA_cs_chunk - out_ptrs -= stride_out_chunk - if CONVERT_STATES: - states_converted_ptrs -= stride_out_chunk - if CONVERT_STATES: - out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - tl.store(states_converted_ptrs, out, mask=offs_m < dim) - if not HAS_DINITSTATES: - tl.store(ddA_cs_ptr, 0.0) - else: - dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) - if HAS_SEQ_IDX: - scale = tl.where(seq_idx == 0, scale, 0.0) - out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - ddA = tl.sum(out * dstates) * scale - tl.store(ddA_cs_ptr, ddA) - dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - dstates = scale * dstates + dout - tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim) - - def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None): batch, nchunks, nheads, dim = states.shape @@ -220,129 +110,3 @@ def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=Non HAS_SEQ_IDX=seq_idx is not None, ) return out, final_states - - -def _state_passing_bwd( - states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None, - dstates_dtype=None, states_dtype=None, chunk_size=None -): - """ - states contains the initial_states at index 0. The final states are not included in states. - """ - batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) - assert dout.shape == (batch, nchunks, nheads, dim) - if seq_idx is not None: - assert chunk_size is not None - seqlen = seq_idx.shape[-1] - assert seq_idx.shape == (batch, seqlen) - dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) - if states_dtype is not None and states_dtype != states.dtype: - states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) - assert states_converted.stride() == states.stride() - else: - states_converted = None - if has_initial_states: - dinitstates = torch.empty_like(dstates[:, 0]) - else: - dinitstates = None - if dfinal_states is not None: - assert dfinal_states.shape == (batch, nheads, dim) - BLOCK_SIZE_min = 64 - n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min - ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks, - dtype=torch.float32, device=dA_chunk_cumsum.device) - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) - with torch.cuda.device(dout.device.index): - _state_passing_bwd_kernel[grid]( - dout, states, dA_chunk_cumsum, dfinal_states, seq_idx, - dstates, ddA_chunk_cumsum, dinitstates, states_converted, - dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, - dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), - states.stride(0), states.stride(1), states.stride(2), states.stride(3), - dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), - *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2)) - if dfinal_states is not None else (0, 0, 0)), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), - ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1), - *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2)) - if dinitstates is not None else (0, 0, 0)), - CONVERT_STATES=states_converted is not None, - HAS_DFINAL_STATES=dfinal_states is not None, - HAS_DINITSTATES=dinitstates is not None, - HAS_SEQ_IDX=seq_idx is not None, - ) - BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"] - n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual - ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype) - if states_dtype is not None and states_dtype == states.dtype: - states_converted = states - return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted) - - -class StatePassingFn(torch.autograd.Function): - - @staticmethod - def forward(ctx, states, dA_chunk_cumsum, initial_states=None): - batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) - if states.stride(-1) != 1: - states = states.contiguous() - out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states) - ctx.save_for_backward(out, dA_chunk_cumsum) - ctx.has_initial_states = initial_states is not None - return out, final_states - - @staticmethod - def backward(ctx, dout, dfinal_states): - out, dA_chunk_cumsum = ctx.saved_tensors - batch, nchunks, nheads, dim = out.shape - assert dout.shape == (batch, nchunks, nheads, dim) - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) - assert dfinal_states.shape == (batch, nheads, dim) - if dout.stride(-1) != 1: - dout = dout.contiguous() - dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd( - out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states - ) - return dstates, ddA_chunk_cumsum, dinitstates - - -def state_passing(states, dA_chunk_cumsum, initial_states=None): - """ - Argument: - states: (batch, nchunks, nheads, dim) - dA_chunk_cumsum: (batch, nheads, nchunks) - initial_states: (batch, nheads, dim) - Return: - out: (batch, nchunks, nheads, dim) - final_states: (batch, nheads, dim) - """ - return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states) - - -def state_passing_ref(states, dA_chunk_cumsum, initial_states=None): - """ - Argument: - states: (batch, nchunks, nheads, dim) - dA_chunk_cumsum: (batch, nheads, nchunks) - initial_states: (batch, nheads, dim) - Return: - out: (batch, nchunks, nheads, dim) - final_states: (batch, nheads, dim) - """ - if initial_states is None: - initial_states = torch.zeros_like(states[:, 0]) - states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1) - dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0)) - dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1) - nchunks = dA_chunk_cumsum.shape[-1] - # (batch, nheads, nchunks, nchunks) - dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :] - # (batch, nheads, nchunks, nchunks) - decay_chunk = torch.exp(dt_chunk_segment_sum) - causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0) - decay_chunk = decay_chunk.masked_fill(~causal_mask, 0) - out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states) - return out[:, :-1], out[:, -1] diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index a12ee30798c68..5c6a8ab043170 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -8,7 +8,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -28,8 +28,6 @@ MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, - _get_graph_batch_size) from .interfaces import HasInnerState, SupportsLoRA from .utils import maybe_prefix @@ -418,7 +416,7 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (_get_graph_batch_size( + max_batch_size = (VllmConfig.get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) From b2dc5cad298ffeb5c4d24209a1686532e7d75abc Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 12 Dec 2024 06:23:45 +0000 Subject: [PATCH 06/19] fmt + lint Signed-off-by: Yu Chin Fabian Lim --- .../decoder_only/language/test_bamba.py | 7 +- .../layers/mamba/mamba_mixer2.py | 233 +++---- .../layers/mamba/ops/softplus.py | 10 +- .../layers/mamba/ops/ssd_bmm.py | 213 +++++-- .../layers/mamba/ops/ssd_chunk_scan.py | 393 +++++++++--- .../layers/mamba/ops/ssd_chunk_state.py | 592 ++++++++++++++---- .../layers/mamba/ops/ssd_combined.py | 131 +++- .../layers/mamba/ops/ssd_state_passing.py | 102 ++- vllm/model_executor/models/bamba.py | 61 +- 9 files changed, 1313 insertions(+), 429 deletions(-) diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py index a3bcb644baf8b..d266135360563 100644 --- a/tests/models/decoder_only/language/test_bamba.py +++ b/tests/models/decoder_only/language/test_bamba.py @@ -1,6 +1,6 @@ """Compare the outputs of HF and vLLM when using greedy sampling for Mamba. -This actually is really indentical to test_mamba, so maybe we can reuse +This actually is really identical to test_mamba, so maybe we can reuse Run `pytest tests/models/decoder_only/language/test_bamba.py`. """ @@ -97,6 +97,7 @@ def test_batching( name_1="batched_vllm", ) + @pytest.mark.skip("bamba does not support chunked prefill yet") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @@ -122,6 +123,7 @@ def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + @pytest.mark.skip("bamba does not support chunked prefill yet") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @@ -205,7 +207,8 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == VllmConfig.get_graph_batch_size(len(example_prompts)): + while len(example_prompts) == VllmConfig.get_graph_batch_size( + len(example_prompts)): example_prompts.append(example_prompts[0]) try: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index b2a4b2aaefc78..150ee86b4ca3b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,36 +1,35 @@ +from typing import List, Optional, Tuple, Union + import torch from torch import nn -from torch.nn.parameter import Parameter - -# Added by the IBM Team, 2024 from vllm.attention.backends.abstract import AttentionMetadata +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) - -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_state_update) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + LoaderFunction, composed_weight_loader, sharded_weight_loader) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs -from vllm.distributed import (divide, get_tensor_model_parallel_world_size, - get_tensor_model_parallel_rank, - tensor_model_parallel_all_reduce) -from vllm.model_executor.model_loader.weight_utils import ( - composed_weight_loader, sharded_weight_loader, LoaderFunction) -from typing import Tuple, Union, Optional, List -from vllm.model_executor.custom_op import CustomOp +# Added by the IBM Team, 2024 + # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated # also referenced https://github.com/vllm-project/vllm/pull/9292 @CustomOp.register("mixer2_gated_rms_norm") class Mixer2RMSNormGated(CustomOp): + def __init__(self, hidden_size, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -84,6 +83,7 @@ def forward_cuda( ) return out + def extra_groups_for_head_shards(ngroups: int, tp_size: int): """Compute the extra (logical) groups to account for head shards""" @@ -93,12 +93,16 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int): return tp_size - ngroups % tp_size + def mamba_v2_sharded_weight_loader( - shard_spec: List[int], tp_size: int, tp_rank: int, + shard_spec: List[Tuple[int, int, float]], + tp_size: int, + tp_rank: int, ) -> LoaderFunction: - """Create a weight loader for mamba v2. This ensures that the projections are - correctly sharded so that they can be split into x, B, C. It also ensures the - the all the groups corresponding to a head shard is placed together with it. + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures the the all the groups corresponding to a head shard is placed + together with it. """ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: @@ -116,18 +120,21 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: rank = tp_rank // ratio # - should start from here (determined by rank) - loaded_skip = rank * shard_size # take these number dims from loaded + # - take these number dims from loaded + loaded_skip = rank * shard_size loaded_start_idx = loaded_boundary + loaded_skip # - these many number dims to take from loaded_weight take = min(shard_size, full_dim - extra - loaded_skip) # - always shard on dim 0 - param.data[ - boundary:boundary+take,... - ] = loaded_weight[ - loaded_start_idx:loaded_start_idx+take - ] + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[boundary:(boundary + take), # type: ignore[misc] + ...] = loaded_weight[ + loaded_start_idx:( # type: ignore[misc] + loaded_start_idx + take)] # type: ignore[misc] # move boundaries boundary += shard_size @@ -135,8 +142,9 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: return loader + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer -@CustomOp.register("mamba_mixer2") +@CustomOp.register("mamba_mixer2") class MambaMixer2(CustomOp): """ Compute ∆, A, B, C, and D the state space parameters and compute @@ -165,17 +173,17 @@ def __init__(self, super().__init__() # For TP, the sharding plan is as follows: - # - for the conv modules, since + # - for the conv modules, since # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, # we shard intermediate_size and n_groups # - since intermediate_size = n_heads * head_dim, sharding on # intermediate_size is achieved by sharding on n_heads. - # - so if world_size divides groups, then sharding + # - so if world_size divides groups, then sharding # (n_groups / world_size, n_heads / world_size) # also maintains the invariant n_heads % n_groups == 0 - # - HOWEVER< if world_size DOES NOT divide groups, then we need to allocate - # extra space in the shard, such that the WHOLE GROUP must be placed - # together with the HEAD SHARD. + # - HOWEVER< if world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that the WHOLE GROUP + # must be placed together with the HEAD SHARD. self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() @@ -190,14 +198,14 @@ def __init__(self, self.n_groups = n_groups if n_groups % self.tp_size != 0: - # - for TP we shard conv_dim by sharding on n_groups, - # - but if n_groups cannot divide tp_size, we need to + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to # extend some extra groups - self.n_groups = n_groups + extra_groups_for_head_shards(n_groups, self.tp_size) + self.n_groups = n_groups + extra_groups_for_head_shards( + n_groups, self.tp_size) - self.conv_dim = ( - intermediate_size + 2 * self.n_groups * ssm_state_size - ) + self.conv_dim = (intermediate_size + + 2 * self.n_groups * ssm_state_size) self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, output_size=self.conv_dim, @@ -210,62 +218,76 @@ def __init__(self, # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = ColumnParallelLinear( - input_size=hidden_size, - output_size=intermediate_size + self.conv_dim + self.num_heads, - bias=use_bias, - quant_config=quant_config) + self.in_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size + + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config) - # - because in_proj is a concatenation of 3 weights, we + # - because in_proj is a concatenation of 3 weights, we # need to interleave them before sharding # - use the custom weight loader mamba_v2_sharded_weight_loader # for conv1d.bias, covn1d.weight and in_proj.weight # - need to set these settings, to assign the groups to the head shards group_shard_settings = ( - self.n_groups * self.ssm_state_size, # expected model size - (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned - self.num_heads // n_groups, # ratio for mapping back to original group + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + self.num_heads // + n_groups, # ratio for mapping back to original group ) intemediate_settings = (intermediate_size, 0, 1) head_setings = (self.num_heads, 0, 1) delattr(self.conv1d.bias, "weight_loader") - set_weight_attrs(self.conv1d.bias, { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - intemediate_settings, group_shard_settings, group_shard_settings, - ], - self.tp_size, tp_rank, - ) - }) + set_weight_attrs( + self.conv1d.bias, { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intemediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }) delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs(self.conv1d.weight, { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - intemediate_settings, group_shard_settings, group_shard_settings, - ], - self.tp_size, tp_rank - ) - }) + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + intemediate_settings, + group_shard_settings, + group_shard_settings, + ], self.tp_size, tp_rank) + }) delattr(self.in_proj.weight, "weight_loader") - set_weight_attrs(self.in_proj.weight, { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - intemediate_settings, # for gate - intemediate_settings, group_shard_settings, group_shard_settings, - head_setings, # for dt - ], - self.tp_size, tp_rank - ) - }) - - # - these are TPed by heads to reduce the size of the + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intemediate_settings, # for gate + intemediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) + + # - these are TPed by heads to reduce the size of the # temporal shape self.A = nn.Parameter( torch.empty( - divide(num_heads, self.tp_size), dtype=torch.float32, + divide(num_heads, self.tp_size), + dtype=torch.float32, )) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) @@ -277,16 +299,14 @@ def __init__(self, set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - self.out_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=use_bias, - input_is_parallel=True, - quant_config=quant_config) + self.out_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config) - self.norm = Mixer2RMSNormGated( - intermediate_size // self.tp_size, eps=rms_norm_eps - ) + self.norm = Mixer2RMSNormGated(intermediate_size // self.tp_size, + eps=rms_norm_eps) def forward_native(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, @@ -297,27 +317,27 @@ def forward_cuda(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams): - seq_len, _ = hidden_states.shape groups_time_state_size = self.n_groups * self.ssm_state_size # - doing it differently from mixer v1; little confused with its logic - # - we need to do is to detect if there is any prefill; if there are + # - we need to do is to detect if there is any prefill; if there are # no prefils, then each example will be coming in one sample at a time - # - on the other hand v1 checks for "query_start_loc" and "context_lens_tensor" - # however we have noticed that, even when the samples are coming in - # one at a time, they are still non-NO.e + # - on the other hand v1 checks for "query_start_loc" + # and "context_lens_tensor" however we have noticed that, even + # when the samples are coming in + # one at a time, they are still not NONE, e.g., # * "query_start_loc" = [0, 1, ..] # * "context_lens_tensor" = [8, ...] - has_prefill = attn_metadata.num_prefills > 0 + has_prefill = attn_metadata.num_prefills > 0 # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) gate, hidden_states_B_C, dt = torch.split( projected_states, [ - self.intermediate_size // self.tp_size, - self.conv_dim // self.tp_size, + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, self.num_heads // self.tp_size, ], dim=-1, @@ -335,7 +355,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, # |-------------------- seq_len ---------------------| # |-- query_len ---| - # - "cache_indices" upates the conv_state cache in positions + # - "cache_indices" updates the conv_state cache in positions # pointed to by "mamba_cache_params.state_indices_tensor" hidden_states_B_C = causal_conv1d_fn( hidden_states_B_C.transpose(0, 1), @@ -345,8 +365,8 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc - ).transpose(0, 1)[:seq_len] + query_start_loc=attn_metadata.query_start_loc).transpose( + 0, 1)[:seq_len] else: hidden_states_B_C = causal_conv1d_update( hidden_states_B_C, @@ -354,14 +374,13 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor - ) + conv_state_indices=mamba_cache_params.state_indices_tensor) # - get hidden_states, B and C after depthwise convolution. hidden_states, B, C = torch.split( hidden_states_B_C, [ - self.intermediate_size // self.tp_size, + self.intermediate_size // self.tp_size, groups_time_state_size // self.tp_size, groups_time_state_size // self.tp_size, ], @@ -370,12 +389,12 @@ def forward_cuda(self, hidden_states: torch.Tensor, # 3. State Space Model sequence transformation if has_prefill: - + # FIXME: we are having problems using mamba_chunk_scan_combined # with chunked prefill. This is because there is no # initial_states requires initial_states.shape[0] to match # the batch size, but cu_seqlens requires batch_size = 1. - # Therefore as of now, initial_states and cu_seqlens are + # Therefore as of now, initial_states and cu_seqlens are # mutually exclusive. initial_states = None @@ -385,7 +404,8 @@ def forward_cuda(self, hidden_states: torch.Tensor, # ] scan_output, varlen_state = mamba_chunk_scan_combined( - hidden_states.view(1, seq_len, self.num_heads // self.tp_size, self.head_dim), + hidden_states.view(1, seq_len, self.num_heads // self.tp_size, + self.head_dim), dt.unsqueeze(0), self.A, B.view(1, seq_len, self.n_groups // self.tp_size, -1), @@ -412,15 +432,17 @@ def forward_cuda(self, hidden_states: torch.Tensor, hidden_states = scan_output.view(seq_len, -1) else: - # NOTE: can be optimized? + # NOTE: can be optimized? n_groups = self.n_groups // self.tp_size - A = self.A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) dt = dt[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) D = self.D[:, None, ...].expand(-1, self.head_dim) B = B.view(-1, n_groups, B.shape[1] // n_groups) C = C.view(-1, n_groups, C.shape[1] // n_groups) - hidden_states_reshaped = hidden_states.view(-1, self.num_heads // self.tp_size, self.head_dim) + hidden_states_reshaped = hidden_states.view( + -1, self.num_heads // self.tp_size, self.head_dim) # - the hidden is reshaped into number of current batches # - in this case there is no more prefil, so the batches gen @@ -434,22 +456,21 @@ def forward_cuda(self, hidden_states: torch.Tensor, mamba_cache_params.ssm_state, hidden_states_reshaped, dt, - A, + A, B, C, - D, + D, z=None, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=mamba_cache_params.state_indices_tensor, ) hidden_states = hidden_states.view( - -1, (self.num_heads // self.tp_size) * self.head_dim - ) + -1, (self.num_heads // self.tp_size) * self.head_dim) # # 4. gated MLP hidden_states = self.norm(hidden_states, gate) # # 5. Final linear projection out, _ = self.out_proj(hidden_states) - return out \ No newline at end of file + return out diff --git a/vllm/model_executor/layers/mamba/ops/softplus.py b/vllm/model_executor/layers/mamba/ops/softplus.py index 5541655c66160..5ec75be51bf3b 100644 --- a/vllm/model_executor/layers/mamba/ops/softplus.py +++ b/vllm/model_executor/layers/mamba/ops/softplus.py @@ -1,15 +1,21 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/softplus.py + +# ruff: noqa: E501 + import triton import triton.language as tl from packaging import version TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") - if TRITON3: + @triton.jit def softplus(dt): return tl.math.log(tl.math.exp(dt) + 1) else: + @triton.jit def softplus(dt): - return tl.math.log1p(tl.exp(dt)) \ No newline at end of file + return tl.math.log1p(tl.exp(dt)) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 312a65769b634..3eba3c49b4590 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -1,51 +1,134 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_bmm.py +# ruff: noqa: E501,SIM102 """We want triton==2.1.0 or 2.2.0 for this """ import math -import torch -import torch.nn.functional as F +import torch import triton import triton.language as tl -from einops import rearrange, repeat - def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + return lambda nargs: [ + nargs[name].zero_() for name in names if nargs[name] is not None + ] @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), ], key=['chunk_size', 'K', 'IS_CAUSAL'], ) @triton.jit def _bmm_chunk_fwd_kernel( # Pointers to matrices - a_ptr, b_ptr, out_ptr, seq_idx_ptr, + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, # Matrix dimensions - seqlen, chunk_size, K, ngroups, - stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, - stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk, - stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn, - stride_seq_idx_batch, stride_seq_idx_seqlen, + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, # Meta-parameters IS_CAUSAL: tl.constexpr, dot_dtype: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_ch = tl.program_id(axis=2).to(tl.int64) @@ -65,14 +148,22 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + + offs_n[None, :] * stride_b_seqlen) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype) - b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype) + a = tl.load(a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & + (offs_n[None, :] < chunk_size_limit), + other=0.0).to(dot_dtype) acc += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -81,16 +172,30 @@ def _bmm_chunk_fwd_kernel( offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_SEQ_IDX: chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) out = acc.to(out_ptr.dtype.element_ty) out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) - tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & + (offs_n[None, :] < chunk_size)) + -def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): +def _bmm_chunk_fwd(a, + b, + chunk_size, + seq_idx=None, + causal=False, + output_dtype=None): """ Argument: a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) @@ -117,20 +222,44 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=No nchunks = math.ceil(seqlen / chunk_size) # Allocates output. out_dtype = a.dtype if output_dtype is None else output_dtype - out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size), - device=a.device, dtype=out_dtype) - dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else - (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32)) - grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), - batch, nchunks if not has_groups else nchunks * ngroups) + out = torch.empty( + (batch, nchunks, chunk_size, chunk_size) if not has_groups else + (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) + dot_dtype = (tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 + or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + chunk_size, META['BLOCK_SIZE_N']), batch, nchunks + if not has_groups else nchunks * ngroups) with torch.cuda.device(a.device.index): _bmm_chunk_fwd_kernel[grid]( - a, b, out, seq_idx, - seqlen, chunk_size, k, ngroups if has_groups else 1, - a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), - b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1), - out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), causal, dot_dtype, HAS_SEQ_IDX=seq_idx is not None, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 79fa52e0b8c4f..c538aaa464171 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -1,55 +1,175 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_chunk_scan.py +# ruff: noqa: E501 """We want triton==2.1.0 or 2.2.0 for this """ -from packaging import version - import torch - import triton import triton.language as tl +from packaging import version TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + return lambda nargs: [ + nargs[name].zero_() for name in names if nargs[name] is not None + ] @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), ], key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) @triton.jit def _chunk_scan_fwd_kernel( # Pointers to matrices - cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + prev_states_ptr, + D_ptr, # Matrix dimensions - chunk_size, hdim, dstate, - batch, seqlen, nheads_ngroups_ratio, + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, # Strides - stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, - stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, - stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, - stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, stride_D_head, # Meta-parameters IS_CAUSAL: tl.constexpr, @@ -57,7 +177,9 @@ def _chunk_scan_fwd_kernel( D_HAS_HDIM: tl.constexpr, HAS_Z: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, ): @@ -68,23 +190,31 @@ def _chunk_scan_fwd_kernel( num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + ( + pid_h // nheads_ngroups_ratio) * stride_cb_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_C_head prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, + mask=offs_m < chunk_size, + other=0.0).to(tl.float32) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=pid_c >= 1, + other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Without the if (pid_c > -1), with Triton 2.1.0, I get @@ -92,23 +222,40 @@ def _chunk_scan_fwd_kernel( # With Triton 2.2.0, this works if IS_TRITON_22 or pid_c > -1: # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * stride_states_hdim + + offs_k_dstate[:, None] * stride_states_dstate) if not HAS_SEQ_IDX: scale_m = tl.exp(dA_cs_m) else: scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) if BLOCK_SIZE_DSTATE <= 128: - C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K @@ -116,24 +263,36 @@ def _chunk_scan_fwd_kernel( acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) - x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim) dt_ptrs = dt_ptr + offs_k * stride_dt_csize dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + K_MAX = chunk_size_limit if not IS_CAUSAL else min( + (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) for k in range(0, K_MAX, BLOCK_SIZE_K): - cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb = tl.load(cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. - cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: mask = offs_m[:, None] >= k + offs_k[None, :] cb = tl.where(mask, cb, 0.0) cb = cb.to(x_ptr.dtype.element_ty) - x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < hdim), + other=0.0) acc += tl.dot(cb, x) cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k x_ptrs += BLOCK_SIZE_K * stride_x_seqlen @@ -145,28 +304,54 @@ def _chunk_scan_fwd_kernel( if HAS_D: if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, + mask=offs_n < hdim, + other=0.0).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), - mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_n[None, :] < hdim), + other=0.0).to(tl.float32) acc += x_residual * D if HAS_Z: out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) - tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :]) + tl.store(out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head - z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) - z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim), + other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) - tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + -def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): +def _chunk_scan_fwd(cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = C.shape @@ -176,36 +361,88 @@ def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=Non if z is not None: assert z.shape == x.shape if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) # Allocates output. - out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + out = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) if z is not None: - out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + out_x = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) assert out_x.stride() == out.stride() else: out_x = None - grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) - z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) - if z is not None else (0, 0, 0, 0)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) _chunk_scan_fwd_kernel[grid]( - cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D, - chunk_size, headdim, dstate, - batch, seqlen, nheads // ngroups, - cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - z_strides[0], z_strides[1], z_strides[2], z_strides[3], - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - C.stride(0), C.stride(1), C.stride(2), C.stride(3), - states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else + (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), D.stride(0) if D is not None else 0, True, D is not None, @@ -215,4 +452,4 @@ def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=Non HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, ) - return out, out_x \ No newline at end of file + return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 3184bbbf03d41..bafdcd2585e5a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -1,22 +1,24 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_chunk_state.py +# ruff: noqa: E501 """We want triton==2.1.0 or 2.2.0 for this """ import math -import torch -import torch.nn.functional as F +import torch import triton import triton.language as tl -from einops import rearrange, repeat - from .softplus import softplus def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + return lambda nargs: [ + nargs[name].zero_() for name in names if nargs[name] is not None + ] + @triton.autotune( configs=[ @@ -33,20 +35,37 @@ def init_to_zero(names): @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices - dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, # Matrix dimension - batch, seqlen, nheads, chunk_size, - dt_min, dt_max, + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, # Strides - stride_dt_batch, stride_dt_seqlen, stride_dt_head, + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, stride_A_head, stride_dt_bias_head, - stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, # Meta-parameters DT_SOFTPLUS: tl.constexpr, HAS_DT_BIAS: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, ): pid_b = tl.program_id(axis=0) @@ -60,60 +79,165 @@ def _chunk_cumsum_fwd_kernel( offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) - dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + + offs_c[None, :] * stride_dt_seqlen) A_ptrs = A_ptr + offs_h * stride_A_head - dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize) - dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + + offs_c[None, :] * stride_dA_cs_csize) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + dt = tl.load(dt_ptrs, + mask=(offs_h[:, None] < nheads) & + (offs_c[None, :] < chunk_size_limit), + other=0.0).to(tl.float32) if HAS_DT_BIAS: - dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, + mask=offs_h < nheads, + other=0.0).to(tl.float32) dt += dt_bias[:, None] if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, softplus(dt), dt) # As of Triton 2.2.0, tl.clamp is not available yet # dt = tl.clamp(dt, dt_min, dt_max) dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) - dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) - tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + 0.0) + tl.store(dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) - tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + tl.store(dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), ], key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit def _chunk_state_fwd_kernel( # Pointers to matrices - x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, # Matrix dimensions - hdim, dstate, chunk_size, - batch, seqlen, nheads_ngroups_ratio, + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, # Strides - stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, - stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, - stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, # Meta-parameters HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch @@ -122,7 +246,8 @@ def _chunk_state_fwd_kernel( num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head @@ -132,30 +257,46 @@ def _chunk_state_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize if HAS_SEQ_IDX: seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) if HAS_SEQ_IDX: - seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) if HAS_SEQ_IDX: - seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) if not HAS_SEQ_IDX: - scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k else: - scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -170,40 +311,130 @@ def _chunk_state_fwd_kernel( states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) + @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), ], key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit def _chunk_state_varlen_kernel( # Pointers to matrices - x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr, + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, # Matrix dimensions - hdim, dstate, chunk_size, - seqlen, nheads_ngroups_ratio, + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, # Strides - stride_x_seqlen, stride_x_head, stride_x_hdim, - stride_b_seqlen, stride_b_head, stride_b_dstate, - stride_dt_chunk, stride_dt_head, stride_dt_csize, - stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, - stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate, - stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) @@ -212,7 +443,8 @@ def _chunk_state_varlen_kernel( pid_n = tl.program_id(axis=0) % num_pid_n end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) pid_c = (end_idx - 1) // chunk_size - b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + b_ptr += pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head @@ -221,10 +453,13 @@ def _chunk_state_varlen_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * + stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize chunk_size_limit = end_idx - pid_c * chunk_size @@ -233,12 +468,24 @@ def _chunk_state_varlen_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k) & + (offs_k[None, :] >= start_idx_cur - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate) & + (offs_k[:, None] >= start_idx_cur - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -249,8 +496,13 @@ def _chunk_state_varlen_kernel( # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk if start_idx < pid_c * chunk_size: - chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate) - chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + chunk_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + chunk_states = tl.load(chunk_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0) scale = tl.exp(dA_cs_last) acc += chunk_states * scale @@ -260,37 +512,77 @@ def _chunk_state_varlen_kernel( states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) -def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): +def _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): batch, seqlen, nheads = dt.shape - assert A.shape == (nheads,) + assert A.shape == (nheads, ) if dt_bias is not None: - assert dt_bias.shape == (nheads,) + assert dt_bias.shape == (nheads, ) nchunks = math.ceil(seqlen / chunk_size) - dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + dt_out = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + dA_cumsum = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, + triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( - dt, A, dt_bias, dt_out, dA_cumsum, - batch, seqlen, nheads, chunk_size, - dt_limit[0], dt_limit[1], - dt.stride(0), dt.stride(1), dt.stride(2), + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), A.stride(0), dt_bias.stride(0) if dt_bias is not None else 0, - dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out -def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True): + +def _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=None, + states=None, + states_in_fp32=True): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = B.shape @@ -304,24 +596,54 @@ def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_f assert states.shape == (batch, nchunks, nheads, headdim, dstate) else: states_dtype = torch.float32 if states_in_fp32 else B.dtype - states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) - grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), - batch * nchunks, nheads) + states = torch.empty((batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype) + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( + dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_state_fwd_kernel[grid]( - x, B, states, dt, dA_cumsum, seq_idx, - headdim, dstate, chunk_size, - batch, seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), x.stride(3), - B.stride(0), B.stride(1), B.stride(2), B.stride(-1), - states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), - dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), - dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_SEQ_IDX=seq_idx is not None, ) return states + def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): total_seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape @@ -333,19 +655,47 @@ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape assert chunk_states.shape == (nchunks, nheads, headdim, dstate) - states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device) - grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), - batch, nheads) + states = torch.empty(batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. + cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) with torch.cuda.device(x.device.index): _chunk_state_varlen_kernel[grid]( - x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states, - headdim, dstate, chunk_size, - total_seqlen, nheads // ngroups, - x.stride(0), x.stride(1), x.stride(2), - B.stride(0), B.stride(1), B.stride(2), - dt.stride(1), dt.stride(0), dt.stride(2), - dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2), - chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3), - states.stride(0), states.stride(1), states.stride(2), states.stride(3), + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), ) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 728024a6b31fa..90854fd0c0a10 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -1,50 +1,67 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_combined.py +# ruff: noqa: E501 """We want triton==2.1.0 or 2.2.0 for this """ -from packaging import version - import torch - import triton - from einops import rearrange +from packaging import version from .ssd_bmm import _bmm_chunk_fwd -from .ssd_chunk_state import _chunk_cumsum_fwd -from .ssd_chunk_state import _chunk_state_fwd -from .ssd_chunk_state import chunk_state_varlen -from .ssd_state_passing import _state_passing_fwd from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, + chunk_state_varlen) +from .ssd_state_passing import _state_passing_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + def init_to_zero(names): - return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + return lambda nargs: [ + nargs[name].zero_() for name in names if nargs[name] is not None + ] -def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + +def _mamba_chunk_scan_combined_fwd(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (batch, seqlen, ngroups, dstate) assert x.shape == (batch, seqlen, nheads, headdim) assert dt.shape == (batch, seqlen, nheads) - assert A.shape == (nheads,) + assert A.shape == (nheads, ) assert C.shape == B.shape if z is not None: assert z.shape == x.shape if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() - if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + if x.stride(-1) != 1 and x.stride( + 1) != 1: # Either M or K dimension should be contiguous x = x.contiguous() - if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + if z is not None and z.stride(-1) != 1 and z.stride( + 1) != 1: # Either M or K dimension should be contiguous z = z.contiguous() if D is not None and D.stride(-1) != 1: D = D.contiguous() @@ -54,28 +71,73 @@ def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, d # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) - dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) - states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + dA_cumsum, dt = _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + states = _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=seq_idx, + states_in_fp32=True) # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True) # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True) # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True) - states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], - initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, - seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype) - states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]] + states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None else None, + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=C.dtype) + states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states]) # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) - CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) - out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx) + CB = _bmm_chunk_fwd(C, + B, + chunk_size, + seq_idx=seq_idx, + output_dtype=torch.float32) + out, out_x = _chunk_scan_fwd(CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx) if cu_seqlens is None: return out, out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0), + varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), + dt.squeeze(0), dA_cumsum.squeeze(0), cu_seqlens, states.squeeze(0)) return out, out_x, dt, dA_cumsum, states, final_states, varlen_states -def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): + +def mamba_chunk_scan_combined(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_final_states=False, + return_varlen_states=False): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -99,9 +161,26 @@ def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bia cu_seqlens = None else: assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" - out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit) if not return_varlen_states: return out if not return_final_states else (out, final_states) else: varlen_states = rest[0] - return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) \ No newline at end of file + return (out, + varlen_states) if not return_final_states else (out, + final_states, + varlen_states) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 59ed1d17cfda2..dfc87fc7e5c68 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -1,10 +1,11 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_state_passing.py +# ruff: noqa: E501 """We want triton==2.1.0 or 2.2.0 for this """ import torch - import triton import triton.language as tl @@ -23,16 +24,37 @@ @triton.jit def _state_passing_fwd_kernel( # Pointers to matrices - states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, # Matrix dimensions - dim, nchunks, seqlen, chunk_size, + dim, + nchunks, + seqlen, + chunk_size, # Strides - stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim, - stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, - stride_final_states_batch, stride_final_states_head, stride_final_states_dim, - stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, - stride_initstates_batch, stride_initstates_head, stride_initstates_dim, - stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, # Meta-parameters HAS_INITSTATES: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, @@ -59,16 +81,20 @@ def _state_passing_fwd_kernel( states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) else: initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim - states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + states = tl.load(initstates_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk seq_idx = 0 for c in range(nchunks): - new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + new_states = tl.load(states_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale = tl.exp(dA_cs) if HAS_SEQ_IDX: - seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) + seq_idx_new = tl.load(seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * + stride_seq_idx_seqlen) scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) seq_idx = seq_idx_new states = scale * states + new_states @@ -81,7 +107,11 @@ def _state_passing_fwd_kernel( out_ptrs += stride_out_chunk -def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, +def _state_passing_fwd(states, + dA_chunk_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, out_dtype=None): batch, nchunks, nheads, dim = states.shape assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) @@ -92,20 +122,44 @@ def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=Non seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype - out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) - final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32) + out = torch.empty((batch, nchunks, nheads, dim), + device=states.device, + dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), + device=states.device, + dtype=torch.float32) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( - states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx, - dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, - states.stride(0), states.stride(1), states.stride(2), states.stride(3), - out.stride(0), out.stride(1), out.stride(2), out.stride(3), - final_states.stride(0), final_states.stride(1), final_states.stride(2), - dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), - *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) - if initial_states is not None else (0, 0, 0)), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + states, + out, + final_states, + dA_chunk_cumsum, + initial_states, + seq_idx, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size if seq_idx is not None else 0, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_chunk_cumsum.stride(0), + dA_chunk_cumsum.stride(2), + dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2)) if initial_states is not None else + (0, 0, 0)), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_INITSTATES=initial_states is not None, HAS_SEQ_IDX=seq_idx is not None, ) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 5c6a8ab043170..2693c45b27520 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -10,16 +10,16 @@ from vllm.attention.layer import Attention from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -67,6 +67,7 @@ def forward(self, x): x, _ = self.down_proj(x) return x + class BambaMixerDecoderLayer(nn.Module): def __init__(self, @@ -161,7 +162,7 @@ def __init__( max_position_embeddings=max_position_embeddings, base=rope_theta, is_neox_style=True, - dtype=torch.get_default_dtype(), # see impl of get_rope + dtype=torch.get_default_dtype(), # see impl of get_rope ) self.qkv_proj = QKVParallelLinear( @@ -203,23 +204,28 @@ def self_attention( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # because the bamba model may potentially handle long sequences, - # we should adjust the sin_cos cache if necesary to avoid out of bounds + # because the bamba model may potentially handle long sequences, + # we should adjust the sin_cos cache if necessary to avoid out of bounds # - first get the max_position max_position = max( getattr(attn_metadata, 'max_prefill_seq_len', 0), getattr(attn_metadata, 'max_decode_seq_len', 0), ) if max_position == 0: - # if we cannot get the max lenght from the metadata, then - # get it frmo the positions + # if we cannot get the max length from the metadata, then + # get it from the positions max_position = positions.max().item() - if self.rotary_emb.max_position_embeddings <= max_position: + # when VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 could potentially cause inputs + # longer than max_position_embeddings. We extend the rope cache + # to prevent CUDA errors. Be aware that the outputs could be of + # lower quality for long sequence lengths. + rotary = self.rotary_emb + if rotary.max_position_embeddings <= max_position: # we set it to the next power of two that covers it - while self.rotary_emb.max_position_embeddings <= max_position: - self.rotary_emb.max_position_embeddings *= 2 - self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache() + while rotary.max_position_embeddings <= max_position: + rotary.max_position_embeddings *= 2 + rotary.cos_sin_cache = rotary._compute_cos_sin_cache() q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) @@ -260,6 +266,7 @@ def forward( "mamba": BambaMixerDecoderLayer } + class BambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -312,10 +319,11 @@ def forward( # add additional attn_metadata for the mixer layers if attn_metadata.num_prefills > 0: sed_idx = torch.zeros_like(input_ids, dtype=torch.int32) - for i, (srt, end) in enumerate(zip( - attn_metadata.query_start_loc, - attn_metadata.query_start_loc[1:], - )): + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): sed_idx[srt:end] = i attn_metadata.seq_idx = sed_idx @@ -335,7 +343,8 @@ def forward( layer_mamba_cache_params = None if isinstance(layer, BambaMixerDecoderLayer): - layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i - num_attn) + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) hidden_states, residual = layer( positions=positions, @@ -457,18 +466,14 @@ def _get_mamba_cache_shape( intermediate_size = self.config.mamba_expand * hidden_size - # if n_groups is not divisible by world_size, need to extend the shards to ensure - # all groups needed by a head is sharded along with it - n_groups = ( - self.config.mamba_n_groups + - extra_groups_for_head_shards(self.config.mamba_n_groups, world_size) - ) + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size)) # - heads and n_groups are TP-ed - conv_dim = ( - intermediate_size + - 2 * n_groups * self.config.mamba_d_state - ) + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) conv_state_shape = ( divide(conv_dim, world_size), self.config.mamba_d_conv - 1, From 9ad9e20723c2a7e46ce1aee9759424f0ea64b03c Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 12 Dec 2024 07:11:31 +0000 Subject: [PATCH 07/19] more comments Signed-off-by: Yu Chin Fabian Lim --- .../decoder_only/language/test_bamba.py | 2 -- .../layers/mamba/ops/ssd_combined.py | 32 +++++++++++++------ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py index d266135360563..96efdc59081d3 100644 --- a/tests/models/decoder_only/language/test_bamba.py +++ b/tests/models/decoder_only/language/test_bamba.py @@ -20,8 +20,6 @@ # choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. def generate_greedy(model_name, example_prompts, max_tokens): # Create a text generation pipeline - # - in the original test_mamba.py they do not put the model to cuda - # maybe this affects the test. tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 90854fd0c0a10..579663a76fb7b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -67,25 +67,36 @@ def _mamba_chunk_scan_combined_fwd(x, D = D.contiguous() if initial_states is not None: assert initial_states.shape == (batch, nheads, headdim, dstate) - # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size) - # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) - # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) - # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) - # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True) - # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True) - # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], @@ -96,13 +107,16 @@ def _mamba_chunk_scan_combined_fwd(x, out_dtype=C.dtype) states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]) - # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) - # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) + + # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. out, out_x = _chunk_scan_fwd(CB, x, dt, From 25bf3810b0b30e892574ef6a83b949f5b0898903 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 12 Dec 2024 07:41:31 +0000 Subject: [PATCH 08/19] initial fix for chunked prefill (incomplete) Signed-off-by: Yu Chin Fabian Lim --- .../layers/mamba/ops/ssd_combined.py | 40 ++++++++++---- .../layers/mamba/ops/ssd_state_passing.py | 55 ++++++++++++++++--- 2 files changed, 75 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 579663a76fb7b..03eaec168076d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -66,7 +66,11 @@ def _mamba_chunk_scan_combined_fwd(x, if D is not None and D.stride(-1) != 1: D = D.contiguous() if initial_states is not None: - assert initial_states.shape == (batch, nheads, headdim, dstate) + if cu_seqlens is None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + else: + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, + headdim, dstate) # This function executes 5 sub-functions for computing mamba # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ @@ -97,6 +101,11 @@ def _mamba_chunk_scan_combined_fwd(x, # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states + # ii) seq_idx and iii) has_cu_seqlens to be all specified. + # - When a new seq_idx is detected, we will load the correct initial_state + # and ensure that the output states is correctly updated. + # states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], @@ -104,7 +113,8 @@ def _mamba_chunk_scan_combined_fwd(x, if initial_states is not None else None, seq_idx=seq_idx, chunk_size=chunk_size, - out_dtype=C.dtype) + out_dtype=C.dtype, + has_cu_seqlens=cu_seqlens is not None) states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]) @@ -117,15 +127,23 @@ def _mamba_chunk_scan_combined_fwd(x, # 5. Scan and compute the diagonal blocks, taking into # account past causal states. - out, out_x = _chunk_scan_fwd(CB, - x, - dt, - dA_cumsum, - C, - states, - D=D, - z=z, - seq_idx=seq_idx) + # - NOTE: in addition to the logic in _state_passing_fwd to handle + # chunked prefill, we also need to modify _chunk_scan_fwd to + # - the updates to _state_passing_fwd only handles initial_state + # if the sequences are synced to the chunk boundaries. + # - but in the case where there are offsets from the chunk boundaries + # we need to further update _chunk_scan_fwd (not yet done). + out, out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=(None if cu_seqlens is not None and initial_states is not None + else seq_idx)) if cu_seqlens is None: return out, out_x, dt, dA_cumsum, states, final_states else: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index dfc87fc7e5c68..a4bc87df0e755 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -58,6 +58,7 @@ def _state_passing_fwd_kernel( # Meta-parameters HAS_INITSTATES: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, + HAS_CU_SEQLENS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid_b = tl.program_id(axis=1) @@ -68,7 +69,10 @@ def _state_passing_fwd_kernel( out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: - initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head + initstates_ptr += pid_h * stride_initstates_head + if not HAS_CU_SEQLENS: + initstates_ptr += pid_b * stride_initstates_batch + if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch @@ -95,7 +99,25 @@ def _state_passing_fwd_kernel( seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + if HAS_INITSTATES: + if HAS_CU_SEQLENS and seq_idx != seq_idx_new: + # need to load the initial state for this new sequence + # - override the scanned state + initstates_ptrs += seq_idx_new * stride_initstates_batch + + states = tl.load(initstates_ptrs, + mask=offs_m < dim, + other=0.0).to(tl.float32) + + # in the previous scan iteration, the wrong state was + # written to the output buffer + # - so we also override it + tl.store(out_ptrs - stride_out_chunk, + states, + mask=offs_m < dim) + else: + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + seq_idx = seq_idx_new states = scale * states + new_states if c < nchunks - 1: @@ -107,16 +129,30 @@ def _state_passing_fwd_kernel( out_ptrs += stride_out_chunk -def _state_passing_fwd(states, - dA_chunk_cumsum, - initial_states=None, - seq_idx=None, - chunk_size=None, - out_dtype=None): +def _state_passing_fwd( + states, + dA_chunk_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, + out_dtype=None, + has_cu_seqlens=False, +): batch, nchunks, nheads, dim = states.shape assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) if initial_states is not None: - assert initial_states.shape == (batch, nheads, dim) + if has_cu_seqlens: + # - if cu_seqlens is provided, then the initial states + # are used for continuous batching. In which case we + # require seq_idx to be provided + assert seq_idx is not None, "" + assert initial_states.shape == (seq_idx.max().item() + 1, nheads, + dim) + else: + # - this is the regular batching case, where initial + # states are used are for each example of the batch. + assert initial_states.shape == (batch, nheads, dim) + if seq_idx is not None: assert chunk_size is not None seqlen = seq_idx.shape[-1] @@ -162,5 +198,6 @@ def _state_passing_fwd(states, seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_INITSTATES=initial_states is not None, HAS_SEQ_IDX=seq_idx is not None, + HAS_CU_SEQLENS=has_cu_seqlens, ) return out, final_states From 43ce07cb8556f1d30ef27c845201c8c5fef6384f Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 12 Dec 2024 12:59:04 +0000 Subject: [PATCH 09/19] improve comments Signed-off-by: Yu Chin Fabian Lim --- .../decoder_only/language/test_bamba.py | 14 ++++--- .../layers/mamba/mamba_mixer2.py | 39 ++++++++++++------- .../layers/mamba/ops/ssd_combined.py | 2 +- .../layers/mamba/ops/ssd_state_passing.py | 12 +++--- 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py index 96efdc59081d3..164bd8d40e03e 100644 --- a/tests/models/decoder_only/language/test_bamba.py +++ b/tests/models/decoder_only/language/test_bamba.py @@ -1,9 +1,3 @@ -"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. - -This actually is really identical to test_mamba, so maybe we can reuse - -Run `pytest tests/models/decoder_only/language/test_bamba.py`. -""" import pytest from transformers import AutoModelForCausalLM, AutoTokenizer @@ -40,6 +34,14 @@ def generate_greedy(model_name, example_prompts, max_tokens): return outputs +"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. + +This actually is really identical to test_mamba, so maybe we can reuse + +Run `pytest tests/models/decoder_only/language/test_bamba.py`. +""" + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 150ee86b4ca3b..72e574a12c52f 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -85,7 +85,8 @@ def forward_cuda( def extra_groups_for_head_shards(ngroups: int, tp_size: int): - """Compute the extra (logical) groups to account for head shards""" + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" # in the case ngoups % tp_size == 0, this will be zero if ngroups % tp_size == 0: @@ -109,22 +110,29 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # - track boundary of (sharded) param, and loaded_weight, respectively boundary, loaded_boundary = 0, 0 - for full_dim, extra, ratio in shard_spec: - # - full dim is the expected size of the model - # - if extra > 0, this means there was some expansion - # - num of dims expected to be loaded + # - iterate over the shard specs + for full_dim, extra, ratio in shard_spec: + # - full dim is the model dim (before TP). + # - extra > 0, means there is expected overall increase + # of dimensions. This is so because of replication. + # - ratio is used map the tp_rank to the actual shard + # rank. This is useful when there is replication of + # groups to accompany head shards. + + # - size of the loaded shard shard_size = full_dim // tp_size - # - compute where to take the loaded shard from + # - compute the rank into the loaded shard. + # - if there is replication, different TP shards will + # take from the same rank. rank = tp_rank // ratio - # - should start from here (determined by rank) - # - take these number dims from loaded + # - leftmost boundary index into loaded weight. loaded_skip = rank * shard_size loaded_start_idx = loaded_boundary + loaded_skip - # - these many number dims to take from loaded_weight + # - take these many dims from the loaded weight. take = min(shard_size, full_dim - extra - loaded_skip) # - always shard on dim 0 @@ -136,7 +144,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: loaded_start_idx:( # type: ignore[misc] loaded_start_idx + take)] # type: ignore[misc] - # move boundaries + # move indexing boundaries boundary += shard_size loaded_boundary += (full_dim - extra) @@ -169,6 +177,7 @@ def __init__(self, head_dim: int = 64, rms_norm_eps: float = 1e-5, activation="silu", + chunk_size: int = 256, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -178,12 +187,12 @@ def __init__(self, # we shard intermediate_size and n_groups # - since intermediate_size = n_heads * head_dim, sharding on # intermediate_size is achieved by sharding on n_heads. - # - so if world_size divides groups, then sharding + # - IF, world_size divides groups, then sharding # (n_groups / world_size, n_heads / world_size) # also maintains the invariant n_heads % n_groups == 0 - # - HOWEVER< if world_size DOES NOT divide groups, then we need - # to allocate extra space in the shard, such that the WHOLE GROUP - # must be placed together with the HEAD SHARD. + # - HOWEVER IF, world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that groups + # may be replicated to follow the head shard. self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() @@ -191,7 +200,7 @@ def __init__(self, self.use_rms_norm = use_rms_norm self.activation = activation - self.chunk_size = 256 + self.chunk_size = chunk_size self.intermediate_size = intermediate_size self.head_dim = head_dim self.num_heads = num_heads diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 03eaec168076d..a9b6c79496ab9 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -114,7 +114,7 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype, - has_cu_seqlens=cu_seqlens is not None) + is_cont_batched=cu_seqlens is not None) states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index a4bc87df0e755..174b21d73b85a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -58,7 +58,7 @@ def _state_passing_fwd_kernel( # Meta-parameters HAS_INITSTATES: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, - HAS_CU_SEQLENS: tl.constexpr, + IS_CONT_BATCHED: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid_b = tl.program_id(axis=1) @@ -70,7 +70,7 @@ def _state_passing_fwd_kernel( final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: initstates_ptr += pid_h * stride_initstates_head - if not HAS_CU_SEQLENS: + if not IS_CONT_BATCHED: initstates_ptr += pid_b * stride_initstates_batch if HAS_SEQ_IDX: @@ -100,7 +100,7 @@ def _state_passing_fwd_kernel( (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) if HAS_INITSTATES: - if HAS_CU_SEQLENS and seq_idx != seq_idx_new: + if IS_CONT_BATCHED and seq_idx != seq_idx_new: # need to load the initial state for this new sequence # - override the scanned state initstates_ptrs += seq_idx_new * stride_initstates_batch @@ -136,12 +136,12 @@ def _state_passing_fwd( seq_idx=None, chunk_size=None, out_dtype=None, - has_cu_seqlens=False, + is_cont_batched=False, ): batch, nchunks, nheads, dim = states.shape assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) if initial_states is not None: - if has_cu_seqlens: + if is_cont_batched: # - if cu_seqlens is provided, then the initial states # are used for continuous batching. In which case we # require seq_idx to be provided @@ -198,6 +198,6 @@ def _state_passing_fwd( seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_INITSTATES=initial_states is not None, HAS_SEQ_IDX=seq_idx is not None, - HAS_CU_SEQLENS=has_cu_seqlens, + IS_CONT_BATCHED=is_cont_batched, ) return out, final_states From 80f14b539d4ea3f883a72b4996cf4f718334e084 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 12 Dec 2024 15:01:15 +0000 Subject: [PATCH 10/19] do not attach seq_idx to attn_metadata Signed-off-by: Yu Chin Fabian Lim --- .../layers/mamba/mamba_mixer2.py | 14 ++++++++----- vllm/model_executor/models/bamba.py | 20 ++++++++++++------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 72e574a12c52f..1b43664875aed 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -191,7 +191,7 @@ def __init__(self, # (n_groups / world_size, n_heads / world_size) # also maintains the invariant n_heads % n_groups == 0 # - HOWEVER IF, world_size DOES NOT divide groups, then we need - # to allocate extra space in the shard, such that groups + # to allocate extra space in the shard, such that groups # may be replicated to follow the head shard. self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() @@ -322,9 +322,13 @@ def forward_native(self, hidden_states: torch.Tensor, conv_state: torch.Tensor, ssm_state: torch.Tensor): pass - def forward_cuda(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams): + def forward_cuda( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + ): seq_len, _ = hidden_states.shape groups_time_state_size = self.n_groups * self.ssm_state_size @@ -423,7 +427,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, D=self.D, z=None, dt_bias=self.dt_bias, - seq_idx=attn_metadata.seq_idx.unsqueeze(0), + seq_idx=sequence_idx, cu_seqlens=attn_metadata.query_start_loc, initial_states=initial_states, return_varlen_states=True, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 2693c45b27520..7ec2a26254d46 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -102,6 +102,7 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, **kwargs, ): if residual is None: @@ -112,7 +113,7 @@ def forward( hidden_states, residual) hidden_states = self.mamba(hidden_states, attn_metadata, - mamba_cache_params) + mamba_cache_params, sequence_idx) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -316,17 +317,19 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # add additional attn_metadata for the mixer layers + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None if attn_metadata.num_prefills > 0: - sed_idx = torch.zeros_like(input_ids, dtype=torch.int32) + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) for i, (srt, end) in enumerate( zip( attn_metadata.query_start_loc, attn_metadata.query_start_loc[1:], )): - sed_idx[srt:end] = i - - attn_metadata.seq_idx = sed_idx + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) if inputs_embeds is not None: hidden_states = inputs_embeds @@ -352,7 +355,9 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, residual=residual, - mamba_cache_params=layer_mamba_cache_params) + mamba_cache_params=layer_mamba_cache_params, + sequence_idx=seq_idx, + ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -364,6 +369,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): "k_proj", "v_proj", ], + "gate_up_proj": ["up_proj", "down_proj"] } # LoRA specific attributes From 6b8ac4910512b48772bcbc74838215c54fbc21de Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 12 Dec 2024 15:15:26 +0000 Subject: [PATCH 11/19] activate initial states for chunked prefill Signed-off-by: Yu Chin Fabian Lim --- .../layers/mamba/mamba_mixer2.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 1b43664875aed..927103212d6c1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -4,6 +4,7 @@ from torch import nn from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -344,6 +345,13 @@ def forward_cuda( # * "context_lens_tensor" = [8, ...] has_prefill = attn_metadata.num_prefills > 0 + # - also need flags to indicate if there are initial states + # - currently we really only support the FlashAttention backend + has_initial_states = None + if (isinstance(attn_metadata, FlashAttentionMetadata) + and attn_metadata.context_lens_tensor is not None): + has_initial_states = attn_metadata.context_lens_tensor > 0 + # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) gate, hidden_states_B_C, dt = torch.split( @@ -376,7 +384,7 @@ def forward_cuda( self.conv1d.bias, activation=self.activation, conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, + has_initial_state=has_initial_states, cache_indices=mamba_cache_params.state_indices_tensor, query_start_loc=attn_metadata.query_start_loc).transpose( 0, 1)[:seq_len] @@ -404,17 +412,14 @@ def forward_cuda( if has_prefill: # FIXME: we are having problems using mamba_chunk_scan_combined - # with chunked prefill. This is because there is no - # initial_states requires initial_states.shape[0] to match - # the batch size, but cu_seqlens requires batch_size = 1. - # Therefore as of now, initial_states and cu_seqlens are - # mutually exclusive. + # with chunked prefill. This is because currently + # chunked_prefill only works if "attn_metadata.query_start_loc" + # is aligned with chunk_size. WIP initial_states = None - # if any(attn_metadata.context_lens_tensor > 0): - # initial_states = mamba_cache_params.ssm_state[ - # mamba_cache_params.state_indices_tensor - # ] + if has_initial_states is not None and any(has_initial_states): + initial_states = mamba_cache_params.ssm_state[ + mamba_cache_params.state_indices_tensor] scan_output, varlen_state = mamba_chunk_scan_combined( hidden_states.view(1, seq_len, self.num_heads // self.tp_size, From d788db694330d27ffc0f269dfbb1ccef3eb82f72 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 13 Dec 2024 01:08:42 +0000 Subject: [PATCH 12/19] reuse softplus and remove triton2 remark Signed-off-by: Yu Chin Fabian Lim --- .../layers/mamba/ops/softplus.py | 21 ------------------- .../layers/mamba/ops/ssd_bmm.py | 9 -------- .../layers/mamba/ops/ssd_chunk_scan.py | 9 -------- .../layers/mamba/ops/ssd_chunk_state.py | 11 +--------- .../layers/mamba/ops/ssd_combined.py | 8 ------- .../layers/mamba/ops/ssd_state_passing.py | 2 -- 6 files changed, 1 insertion(+), 59 deletions(-) delete mode 100644 vllm/model_executor/layers/mamba/ops/softplus.py diff --git a/vllm/model_executor/layers/mamba/ops/softplus.py b/vllm/model_executor/layers/mamba/ops/softplus.py deleted file mode 100644 index 5ec75be51bf3b..0000000000000 --- a/vllm/model_executor/layers/mamba/ops/softplus.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2024, Tri Dao, Albert Gu. -# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/softplus.py - -# ruff: noqa: E501 - -import triton -import triton.language as tl -from packaging import version - -TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") - -if TRITON3: - - @triton.jit - def softplus(dt): - return tl.math.log(tl.math.exp(dt) + 1) -else: - - @triton.jit - def softplus(dt): - return tl.math.log1p(tl.exp(dt)) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 3eba3c49b4590..5560f47b9d34c 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -2,8 +2,6 @@ # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_bmm.py # ruff: noqa: E501,SIM102 -"""We want triton==2.1.0 or 2.2.0 for this -""" import math @@ -11,13 +9,6 @@ import triton import triton.language as tl - -def init_to_zero(names): - return lambda nargs: [ - nargs[name].zero_() for name in names if nargs[name] is not None - ] - - @triton.autotune( configs=[ triton.Config( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index c538aaa464171..226efad6b8fd1 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -2,8 +2,6 @@ # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_chunk_scan.py # ruff: noqa: E501 -"""We want triton==2.1.0 or 2.2.0 for this -""" import torch import triton @@ -12,13 +10,6 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') - -def init_to_zero(names): - return lambda nargs: [ - nargs[name].zero_() for name in names if nargs[name] is not None - ] - - @triton.autotune( configs=[ triton.Config( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index bafdcd2585e5a..551c56a6bb691 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -2,8 +2,6 @@ # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_chunk_state.py # ruff: noqa: E501 -"""We want triton==2.1.0 or 2.2.0 for this -""" import math @@ -11,14 +9,7 @@ import triton import triton.language as tl -from .softplus import softplus - - -def init_to_zero(names): - return lambda nargs: [ - nargs[name].zero_() for name in names if nargs[name] is not None - ] - +from .mamba_ssm import softplus @triton.autotune( configs=[ diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index a9b6c79496ab9..9b5e18368530d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -2,8 +2,6 @@ # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_combined.py # ruff: noqa: E501 -"""We want triton==2.1.0 or 2.2.0 for this -""" import torch import triton @@ -19,12 +17,6 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') -def init_to_zero(names): - return lambda nargs: [ - nargs[name].zero_() for name in names if nargs[name] is not None - ] - - def _mamba_chunk_scan_combined_fwd(x, dt, A, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 174b21d73b85a..5b44ce07a4b85 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -2,8 +2,6 @@ # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_state_passing.py # ruff: noqa: E501 -"""We want triton==2.1.0 or 2.2.0 for this -""" import torch import triton From 400db27d7367a3ad2fdce4d3487c818b1237fee3 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 13 Dec 2024 01:11:37 +0000 Subject: [PATCH 13/19] add comment on weight loader and format Signed-off-by: Yu Chin Fabian Lim --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 5 ++++- vllm/model_executor/layers/mamba/ops/ssd_bmm.py | 1 + vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py | 1 + vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 927103212d6c1..2b019cc702338 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -249,6 +249,10 @@ def __init__(self, intemediate_settings = (intermediate_size, 0, 1) head_setings = (self.num_heads, 0, 1) + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the otther two weights below delattr(self.conv1d.bias, "weight_loader") set_weight_attrs( self.conv1d.bias, { @@ -450,7 +454,6 @@ def forward_cuda( hidden_states = scan_output.view(seq_len, -1) else: - # NOTE: can be optimized? n_groups = self.n_groups // self.tp_size A = self.A[:, None, ...][:, :, None].expand( -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 5560f47b9d34c..a1f7fb06c0e17 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -9,6 +9,7 @@ import triton import triton.language as tl + @triton.autotune( configs=[ triton.Config( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 226efad6b8fd1..ee73720ad7096 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -10,6 +10,7 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + @triton.autotune( configs=[ triton.Config( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 551c56a6bb691..f280aaa9e3021 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -11,6 +11,7 @@ from .mamba_ssm import softplus + @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_H': 1}), From bda8ea7ff84fe71ccc58a0dfcdeddecb6f11bf17 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 13 Dec 2024 03:32:19 +0000 Subject: [PATCH 14/19] rename test_jamba to test_hybrid and got rid of test_bamba Signed-off-by: Yu Chin Fabian Lim --- .../decoder_only/language/test_bamba.py | 329 ------------------ .../{test_jamba.py => test_hybrid.py} | 17 +- 2 files changed, 9 insertions(+), 337 deletions(-) delete mode 100644 tests/models/decoder_only/language/test_bamba.py rename tests/models/decoder_only/language/{test_jamba.py => test_hybrid.py} (95%) diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py deleted file mode 100644 index 164bd8d40e03e..0000000000000 --- a/tests/models/decoder_only/language/test_bamba.py +++ /dev/null @@ -1,329 +0,0 @@ -import pytest -from transformers import AutoModelForCausalLM, AutoTokenizer - -from vllm.config import VllmConfig -from vllm.sampling_params import SamplingParams - -from ...utils import check_outputs_equal - -# will be ch -MODELS = ["ibm-fms/Bamba-9.8b-1.8T-hf"] - - -# Use lower-level interfaces to create this greedy generator, as mamba will -# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. -def generate_greedy(model_name, example_prompts, max_tokens): - # Create a text generation pipeline - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name) - - # Generate texts from the prompts - outputs = [] - for prompt in example_prompts: - # Tokenize the input prompt with truncation - inputs = tokenizer(prompt, return_tensors="pt", truncation=True) - input_ids = inputs["input_ids"] - - # Generate text using the model's generate method directly - generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) - generated_text = tokenizer.decode(generated_ids[0], - skip_special_tokens=True) - - outputs.append((generated_ids[0].tolist(), generated_text)) - - return outputs - - -"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. - -This actually is really identical to test_mamba, so maybe we can reuse - -Run `pytest tests/models/decoder_only/language/test_bamba.py`. -""" - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_models( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - hf_outputs = generate_greedy(model, example_prompts, max_tokens) - - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_batching( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # To pass the small model tests, we need full precision. - for_loop_outputs = [] - with vllm_runner(model, dtype=dtype) as vllm_model: - for prompt in example_prompts: - for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0]) - - batched_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=for_loop_outputs, - outputs_1_lst=batched_outputs, - name_0="for_loop_vllm", - name_1="batched_vllm", - ) - - -@pytest.mark.skip("bamba does not support chunked prefill yet") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [10]) -def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, - model: str, dtype: str, - max_tokens: int) -> None: - # Tests chunked prefill in conjunction with n>1. In this case, prefill is - # populated with decoding tokens and we test that it doesn't fail. - # This test might fail if cache is not allocated correctly for n > 1 - # decoding steps inside a chunked prefill forward pass (where we have both - # prefill and decode together ) - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) - with vllm_runner( - model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=30, - max_num_seqs=10 # forces prefill chunks with decoding - ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) - - -@pytest.mark.skip("bamba does not support chunked prefill yet") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, - chunked_prefill_token_size: int) -> None: - """ - Checks exact match decode between huggingface model and vllm runner with - chunked prefill. - """ - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - - non_chunked = generate_greedy(model, example_prompts, max_tokens) - - with vllm_runner(model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = vllm_model.generate_greedy(example_prompts, - max_tokens=max_tokens) - - check_outputs_equal( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [15]) -def test_parallel_sampling( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - - with vllm_runner(model, dtype=dtype) as vllm_model: - for_loop_outputs = [] - for _ in range(10): - for_loop_outputs.append( - # using example_prompts index 1 instead of 0 since with 0 the - # logprobs get really close and the test doesn't pass - vllm_model.generate_greedy([example_prompts[1]], max_tokens) - [0]) - sampling_params = SamplingParams(n=10, - temperature=0.001, - seed=0, - max_tokens=max_tokens) - n_lt_1_outputs = vllm_model.generate([example_prompts[1]], - sampling_params) - token_ids, texts = n_lt_1_outputs[0] - n_lt_1_outputs = [(token_id, text) - for token_id, text in zip(token_ids, texts)] - - check_outputs_equal( - outputs_0_lst=n_lt_1_outputs, - outputs_1_lst=for_loop_outputs, - name_0="vllm_n_lt_1_outputs", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_mamba_cache_cg_padding( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # This test is for verifying that mamba cache is padded to CG captured - # batch size. If it's not, a torch RuntimeError will be raised because - # tensor dimensions aren't compatible - while len(example_prompts) == VllmConfig.get_graph_batch_size( - len(example_prompts)): - example_prompts.append(example_prompts[0]) - - try: - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) - except RuntimeError: - pytest.fail( - "Couldn't run batch size which is not equal to a Cuda Graph " - "captured batch size. " - "Could be related to mamba cache not padded correctly") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_models_preemption_recompute( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # Tests that outputs are identical with and w/o preemtions (recompute) - assert dtype == "float" - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) - - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the Mamba inner state management doesn't - # collapse in case where the number of incoming requests and - # finished_requests_ids is larger than the maximum Mamba block capacity. - # This could generally happen due to the fact that Mamba does support - # statelessness mechanism where it can cleanup new incoming requests in - # a single step. - try: - with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 100, 10) - except ValueError: - pytest.fail("Mamba inner state wasn't cleaned up properly between" - "steps finished requests registered unnecessarily ") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_state_cleanup( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the Mamba state is cleaned up between - # steps, If its not cleaned, an error would be expected. - try: - with vllm_runner(model, dtype=dtype) as vllm_model: - for _ in range(10): - vllm_model.generate_greedy([example_prompts[0]] * 100, 1) - except ValueError: - pytest.fail("Mamba inner state wasn't cleaned up between states, " - "could be related to finished_requests_ids") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_multistep( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 10, 1) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -def test_multistep_correctness(vllm_runner, model: str, dtype: str, - max_tokens: int, example_prompts) -> None: - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_outputs_multistep = vllm_model.generate_greedy( - example_prompts, max_tokens) - - with vllm_runner(model, num_scheduler_steps=1, - max_num_seqs=2) as vllm_model: - vllm_outputs_single_step = vllm_model.generate_greedy( - example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_outputs_multistep, - outputs_1_lst=vllm_outputs_single_step, - name_0="vllm_outputs_multistep", - name_1="vllm_outputs_single_step", - ) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_hybrid.py similarity index 95% rename from tests/models/decoder_only/language/test_jamba.py rename to tests/models/decoder_only/language/test_hybrid.py index cae25ae9fa2c8..ce602f63af4e2 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -6,7 +6,8 @@ from ...utils import check_outputs_equal -MODELS = ["ai21labs/Jamba-tiny-dev"] +# This test is for the hybrid models +MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-fms/Bamba-9.8b-1.8T-hf"] @pytest.mark.parametrize("model", MODELS) @@ -140,7 +141,7 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [15]) def test_parallel_sampling( vllm_runner, @@ -243,17 +244,17 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( dtype: str, example_prompts, ) -> None: - # This test is for verifying that the Jamba inner state management doesn't + # This test is for verifying that the hybrid inner state management doesn't # collapse in case where the number of incoming requests and # finished_requests_ids is larger than the maximum mamba block capacity. - # This could generally happen due to the fact that Jamba does support + # This could generally happen due to the fact that hybrid does support # statelessness mechanism where it can cleanup new incoming requests in # a single step. try: with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up properly between" + pytest.fail("Hybrid inner state wasn't cleaned up properly between" "steps finished requests registered unnecessarily ") @@ -265,14 +266,14 @@ def test_state_cleanup( dtype: str, example_prompts, ) -> None: - # This test is for verifying that the Jamba state is cleaned up between + # This test is for verifying that the Hybrid state is cleaned up between # steps, If its not cleaned, an error would be expected. try: with vllm_runner(model, dtype=dtype) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up between states, " + pytest.fail("Hybrid inner state wasn't cleaned up between states, " "could be related to finished_requests_ids") @@ -318,7 +319,7 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str, @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) -def test_jamba_distributed_produces_identical_generation( +def test_hybrid_distributed_produces_identical_generation( vllm_runner, model: str, dtype: str, max_tokens: int, example_prompts) -> None: From a74de9f48d97a5cfd2c591782a78cd0d924bcb3b Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Mon, 16 Dec 2024 03:52:41 +0000 Subject: [PATCH 15/19] update bamba to ishybrid and support pp Signed-off-by: Yu Chin Fabian Lim --- vllm/model_executor/models/bamba.py | 98 ++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 7ec2a26254d46..dbee0cc283a06 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -8,8 +8,9 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -28,9 +29,12 @@ MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType -from .interfaces import HasInnerState, SupportsLoRA -from .utils import maybe_prefix +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -291,16 +295,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) - decoder_layers = [] - for i in range(config.num_hidden_layers): - layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]] - decoder_layers.append( - layer_class(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{i}")) - self.layers = nn.ModuleList(decoder_layers) + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layers_block_type[layer_idx]] + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -314,6 +326,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -331,10 +344,17 @@ def forward( seq_idx[srt:end] = i seq_idx.unsqueeze_(0) - if inputs_embeds is not None: - hidden_states = inputs_embeds + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + residual = None num_attn = 0 for i in range(len(self.layers)): @@ -358,11 +378,17 @@ def forward( mamba_cache_params=layer_mamba_cache_params, sequence_idx=seq_idx, ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states -class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): +class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -387,6 +413,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config @@ -419,6 +447,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # follow jamba + if self.scheduler_config is not None and \ + not self.model_config.enforce_eager: + # for compilation + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + elif self.scheduler_config is not None: + # for eager just take the scheduler_config if avail + self.max_batch_size =self.scheduler_config.max_num_seqs + else: + self.max_batch_size = 8192 + 2 + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -431,16 +479,12 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (VllmConfig.get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config - else max(_BATCH_SIZES_TO_CAPTURE) + 2) - layers_type = self.config.layers_block_type - num_mamba_layers = sum( - [layer_type == "mamba" for layer_type in layers_type]) + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) ( mamba_cache_tensors, @@ -452,6 +496,7 @@ def forward(self, state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_params, + intermediate_tensors, inputs_embeds) return hidden_states @@ -543,6 +588,9 @@ def load_weights(self, weights: Iterable[Tuple[str, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -551,6 +599,8 @@ def load_weights(self, weights: Iterable[Tuple[str, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", From b44caa7801debf0d60aab93069e431d4768ae446 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Mon, 16 Dec 2024 04:47:55 +0000 Subject: [PATCH 16/19] lint Signed-off-by: Yu Chin Fabian Lim --- vllm/model_executor/models/bamba.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index dbee0cc283a06..590887716c0aa 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -388,7 +388,8 @@ def forward( return hidden_states -class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): +class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -463,7 +464,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.scheduler_config.max_num_seqs) elif self.scheduler_config is not None: # for eager just take the scheduler_config if avail - self.max_batch_size =self.scheduler_config.max_num_seqs + self.max_batch_size = self.scheduler_config.max_num_seqs else: self.max_batch_size = 8192 + 2 @@ -484,8 +485,8 @@ def forward(self, self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, - *self._get_mamba_cache_shape()) + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) ( mamba_cache_tensors, state_indices_tensor, @@ -496,8 +497,7 @@ def forward(self, state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_params, - intermediate_tensors, - inputs_embeds) + intermediate_tensors, inputs_embeds) return hidden_states @@ -600,7 +600,7 @@ def load_weights(self, weights: Iterable[Tuple[str, if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): - continue + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", From 8cf364489d3b62039bc4bf172d6de362c1867b1c Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Mon, 16 Dec 2024 06:30:05 +0000 Subject: [PATCH 17/19] add unit test for mamba ssd Signed-off-by: Yu Chin Fabian Lim --- tests/kernels/test_mamba_ssm_ssd.py | 124 ++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 tests/kernels/test_mamba_ssm_ssd.py diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py new file mode 100644 index 0000000000000..595520aa6f6e9 --- /dev/null +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -0,0 +1,124 @@ +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from vllm.model_executor.layers.mamba.ops.ssd_combined import mamba_chunk_scan_combined +from vllm.platforms import current_platform + +# Added by the IBM Team, 2024 + +# Adapted from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton + + +def segsum(x): + """More stable segment sum calculation.""" + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): + """ + Arguments: + X: (batch, length, n_heads, d_head) + A: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + Y: (batch, length, n_heads, d_head) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + X, A, B, C = [ + rearrange(x, "b (c l) ... -> b c l ...", l=block_len) + for x in (X, A, B, C) + ] + + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [4, 16, 32]) +@pytest.mark.parametrize("dim", [128, 512]) +def test_mamba_chunk_scan(dim, n_heads, itype): + device = "cuda" + # set seed + current_platform.seed_everything(0) + batch = 1 # batch_size + seqlen = 128 + chunk_size = 32 + d_head = dim // n_heads + + A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) + dt = F.softplus( + torch.randn(batch, seqlen, n_heads, dtype=itype, device=device) - 4) + X = torch.randn((batch, seqlen, n_heads, d_head), + dtype=itype, + device=device) + B = torch.randn((batch, seqlen, n_heads, d_head), + dtype=itype, + device=device) + C = torch.randn((batch, seqlen, n_heads, d_head), + dtype=itype, + device=device) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, + B, C, chunk_size) + + Y, final_state = mamba_chunk_scan_combined(X, + dt, + A, + B, + C, + chunk_size, + D=None, + return_final_states=True) + + # just test the last in sequence + torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=1e-2, rtol=1e1) + + # just test the last head + # NOTE, in the kernel we always cast states to fp32 + torch.testing.assert_close(final_state[:, -1], + final_state_min[:, -1].to(torch.float32), + atol=1e-2, + rtol=1e1) From e375b40eee3d71fca463fb3d807ed147abfd19ce Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Mon, 16 Dec 2024 06:35:54 +0000 Subject: [PATCH 18/19] fix lint Signed-off-by: Yu Chin Fabian Lim --- tests/kernels/test_mamba_ssm_ssd.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py index 595520aa6f6e9..328a91459ff24 100644 --- a/tests/kernels/test_mamba_ssm_ssd.py +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -3,7 +3,8 @@ import torch.nn.functional as F from einops import rearrange, repeat -from vllm.model_executor.layers.mamba.ops.ssd_combined import mamba_chunk_scan_combined +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) from vllm.platforms import current_platform # Added by the IBM Team, 2024 @@ -39,10 +40,8 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): assert X.shape[1] % block_len == 0 # Rearrange into blocks/chunks - X, A, B, C = [ - rearrange(x, "b (c l) ... -> b c l ...", l=block_len) - for x in (X, A, B, C) - ] + X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len) + for x in (X, A, B, C)) A = rearrange(A, "b c l h -> b h c l") A_cumsum = torch.cumsum(A, dim=-1) @@ -53,10 +52,11 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) - decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) - # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at + # chunk boundaries # (middle term of factorization of off-diag blocks; A terms) if initial_states is None: initial_states = torch.zeros_like(states[:, :1]) @@ -70,7 +70,8 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): state_decay_out = torch.exp(A_cumsum) Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Add output of intra-chunk and inter-chunk terms + # (diagonal and off-diagonal blocks) Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") return Y, final_state From dcbae7bea960af4867f07ecc5abbad6c2a51d896 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sat, 21 Dec 2024 14:26:35 +0800 Subject: [PATCH 19/19] full chunked-prefill fix (sans unit tests) Signed-off-by: Yu Chin Fabian Lim --- .../decoder_only/language/test_hybrid.py | 6 +- .../layers/mamba/mamba_mixer2.py | 5 +- .../layers/mamba/ops/ssd_chunk_scan.py | 206 +++++++++++++++--- .../layers/mamba/ops/ssd_chunk_state.py | 80 ++++++- .../layers/mamba/ops/ssd_combined.py | 31 ++- .../layers/mamba/ops/ssd_state_passing.py | 26 ++- 6 files changed, 286 insertions(+), 68 deletions(-) diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py index 22bbb39da0da0..3d1875322a282 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -7,7 +7,7 @@ from ...utils import check_outputs_equal # This test is for the hybrid models -MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-fms/Bamba-9.8b-1.8T-hf"] +MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-fms/Bamba-9B"] @pytest.mark.parametrize("model", MODELS) @@ -103,7 +103,7 @@ def test_mamba_prefill_chunking_with_parallel_sampling( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [10]) def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, @@ -111,6 +111,8 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, # numeric error during prefill chucking produces different generation # compared to w/o prefill chunking for those examples, removed them for now example_prompts.pop(7) + example_prompts.pop(6) + example_prompts.pop(5) example_prompts.pop(2) example_prompts.pop(1) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 2b019cc702338..0b3f9f1028753 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -4,6 +4,7 @@ from torch import nn from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.xformers import XFormersMetadata from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -352,7 +353,7 @@ def forward_cuda( # - also need flags to indicate if there are initial states # - currently we really only support the FlashAttention backend has_initial_states = None - if (isinstance(attn_metadata, FlashAttentionMetadata) + if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata)) and attn_metadata.context_lens_tensor is not None): has_initial_states = attn_metadata.context_lens_tensor > 0 @@ -427,7 +428,7 @@ def forward_cuda( scan_output, varlen_state = mamba_chunk_scan_combined( hidden_states.view(1, seq_len, self.num_heads // self.tp_size, - self.head_dim), + self.head_dim), dt.unsqueeze(0), self.A, B.view(1, seq_len, self.n_groups // self.tp_size, -1), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index ee73720ad7096..a548f11207baa 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -116,8 +116,12 @@ def _chunk_scan_fwd_kernel( dA_cumsum_ptr, seq_idx_ptr, C_ptr, - prev_states_ptr, + states_ptr, D_ptr, + initstates_ptr, + chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, # Matrix dimensions chunk_size, hdim, @@ -162,6 +166,10 @@ def _chunk_scan_fwd_kernel( stride_states_head, stride_states_hdim, stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, stride_D_head, # Meta-parameters IS_CAUSAL: tl.constexpr, @@ -174,62 +182,154 @@ def _chunk_scan_fwd_kernel( BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, ): pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + ( + cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + ( + x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_C_head - prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + + # - seq_idx_prev points to be previous (possibly logical) chunk. + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=pid_c>= 1, + other=0) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, + ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its inital state + # so this edge case is taken care of + if ( + (c_off == 0) and (seq_idx_prev != seq_idx_m) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # - handle chunk state limit + if HAS_INITSTATES: + + # have to split this if otherwise compilation will have problems + dA_cs_m_boundary = 0.0 + + # get the c_idx for the next (logica) chunk + c_idx_n = tl.load( + chunk_indices_ptr + (pid_c+1), + mask=pid_c > -1 and (pid_c+1) < chunk_meta_num, other=-1 # to trigger different chunk + ) + + # - there are things to consider + # A. if c_off > 0 then we need to move the dA_cs bounary to ensure correct + # contribution of past states + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next + # (logical) chunk. + # An equivalent check for B is c_idx == c_idx_n, where there is repetition in + # (logical) chunk indices. + + if (c_idx == c_idx_n) or c_off > 0: + + # get the next offset + c_off_n = tl.load( + chunk_offsets_ptr + (pid_c+1), + mask=pid_c > -1 and (pid_c+1) < chunk_meta_num, other=chunk_size + ) + + # in this case, adjust down the chunk_size_limit + if c_idx == c_idx_n: + chunk_size_limit = min(c_off_n, chunk_size_limit) + + # get the cs at the offset boundary + # - c_off == 0 is a passthrough + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + (pid_m * BLOCK_SIZE_M + c_off -1) * stride_dA_cs_csize, + mask=(pid_m * BLOCK_SIZE_M + c_off -1) > -1, + other=0.0).to(tl.float32) + if HAS_SEQ_IDX: - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=pid_c >= 1, - other=0) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Without the if (pid_c > -1), with Triton 2.1.0, I get # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. # With Triton 2.2.0, this works - if IS_TRITON_22 or pid_c > -1: + if IS_TRITON_22 or c_idx > -1: # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 offs_k_dstate = tl.arange( 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * stride_states_hdim + - offs_k_dstate[:, None] * stride_states_dstate) - if not HAS_SEQ_IDX: - scale_m = tl.exp(dA_cs_m) + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate) + if HAS_SEQ_IDX: + + if not HAS_INITSTATES: + # - this is for continous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # reqiured. + scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) else: - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + scale_m = tl.exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), @@ -254,7 +354,7 @@ def _chunk_scan_fwd_kernel( prev_states_ptrs += BLOCK_SIZE_K acc *= scale_m[:, None] - offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + @@ -291,7 +391,7 @@ def _chunk_scan_fwd_kernel( dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_D: @@ -309,7 +409,7 @@ def _chunk_scan_fwd_kernel( acc += x_residual * D if HAS_Z: - out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) tl.store(out_x_ptrs, @@ -317,7 +417,7 @@ def _chunk_scan_fwd_kernel( mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) - z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) z = tl.load(z_ptrs, @@ -326,7 +426,7 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) - out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) tl.store(out_ptrs, @@ -343,7 +443,9 @@ def _chunk_scan_fwd(cb, states, D=None, z=None, - seq_idx=None): + seq_idx=None, + initial_states=None, + ): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = C.shape @@ -357,8 +459,38 @@ def _chunk_scan_fwd(cb, assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + chunk_indices, chunk_offsets = None, None if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + assert initial_states.shape == (seq_idx[0].max()+1, nheads, headdim, dstate) + + if initial_states.shape[0] == 1: + # no in this case no point to use initial states + initial_states = None + else: + p = 0 + chunk_indices, chunk_offsets = [], [] + for i, idx in enumerate(seq_idx[0]): + o = i % chunk_size + c = idx > p + if o == 0 or c: + # this means we have a change in sequence + # - that does not accur on the chunk boundary + chunk_indices.append(i // chunk_size) + chunk_offsets.append(o) + + if c: + p = idx # new sequence + + chunk_indices = torch.tensor(chunk_indices, dtype=torch.int, device=seq_idx.device) + chunk_offsets = torch.tensor(chunk_offsets, dtype=torch.int, device=seq_idx.device) + # Allocates output. out = torch.empty(batch, seqlen, @@ -376,9 +508,14 @@ def _chunk_scan_fwd(cb, assert out_x.stride() == out.stride() else: out_x = None + + grid = lambda META: (triton.cdiv( chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + headdim, META['BLOCK_SIZE_N']), + batch * nchunks if chunk_offsets is None else len(chunk_offsets), + nheads + ) z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)) _chunk_scan_fwd_kernel[grid]( @@ -393,6 +530,10 @@ def _chunk_scan_fwd(cb, C, states, D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, chunk_size, headdim, dstate, @@ -435,6 +576,12 @@ def _chunk_scan_fwd(cb, states.stride(2), states.stride(3), states.stride(4), + *( + ( + initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), initial_states.stride(3) + ) if initial_states is not None else (0, 0, 0, 0) + ), D.stride(0) if D is not None else 0, True, D is not None, @@ -443,5 +590,6 @@ def _chunk_scan_fwd(cb, HAS_Z=z is not None, HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, ) return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index f280aaa9e3021..731e350399b59 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -396,6 +396,7 @@ def _chunk_state_varlen_kernel( chunk_states_ptr, cu_seqlens_ptr, states_ptr, + initstates_ptr, # Matrix dimensions hdim, dstate, @@ -423,10 +424,15 @@ def _chunk_state_varlen_kernel( stride_states_head, stride_states_hdim, stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) @@ -442,6 +448,12 @@ def _chunk_state_varlen_kernel( dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + if HAS_INITSTATES: + # if there are init states provided, we differentiate between states (which + # are boundary conditions at a chunk boundary) and initstates (which are boundary + # conditions when a new example in a cont batch starts) + initstates_ptr += pid_h * stride_init_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -487,17 +499,49 @@ def _chunk_state_varlen_kernel( dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk - if start_idx < pid_c * chunk_size: - chunk_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim + - offs_n[None, :] * stride_chunk_states_dstate) - chunk_states = tl.load(chunk_states_ptrs, - mask=(offs_m[:, None] < hdim) & - (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) - # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0) - scale = tl.exp(dA_cs_last) - acc += chunk_states * scale + # If HAS_INITSTATES==True need to consider two possiblties + # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs + # - if state_idx >= pid * chunk_size, then we need to insert initstates + if ( + (start_idx < pid_c * chunk_size) # first chunk + or + ( + HAS_INITSTATES + ) + ): + + dA_cs_boundary = 0.0 # default + + if not HAS_INITSTATES: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + + # - this seems repetitve, buts its to help the compiler + if start_idx < pid_c * chunk_size: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate) + + # need to adjust the boundary + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load( + dA_cumsum_ptr + (start_idx - pid_c * chunk_size - 1) * + stride_dA_cs_csize).to(tl.float32) + + past_states = tl.load(past_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale states = acc.to(states_ptr.dtype.element_ty) @@ -636,7 +680,7 @@ def _chunk_state_fwd(B, return states -def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): +def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None): total_seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape _, ngroups, dstate = B.shape @@ -647,6 +691,10 @@ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + states = torch.empty(batch, nheads, headdim, @@ -664,6 +712,7 @@ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): chunk_states, cu_seqlens, states, + initial_states, headdim, dstate, chunk_size, @@ -689,5 +738,12 @@ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): states.stride(1), states.stride(2), states.stride(3), + *( + ( + initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), initial_states.stride(3) + ) if initial_states is not None else (0, 0, 0, 0) + ), + HAS_INITSTATES=initial_states is not None ) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 9b5e18368530d..361190a6ed409 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -95,9 +95,11 @@ def _mamba_chunk_scan_combined_fwd(x, # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states # ii) seq_idx and iii) has_cu_seqlens to be all specified. - # - When a new seq_idx is detected, we will load the correct initial_state - # and ensure that the output states is correctly updated. - # + # - When a new seq_idx is detected, we will stopp passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + # - this will ensure that states will be updated with the righmost flushed seq_idx + # of the previous chunk. This implies that the first chunk of states is either 0 + # or equal to init_states of the first example. states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], @@ -119,12 +121,14 @@ def _mamba_chunk_scan_combined_fwd(x, # 5. Scan and compute the diagonal blocks, taking into # account past causal states. - # - NOTE: in addition to the logic in _state_passing_fwd to handle - # chunked prefill, we also need to modify _chunk_scan_fwd to - # - the updates to _state_passing_fwd only handles initial_state - # if the sequences are synced to the chunk boundaries. - # - but in the case where there are offsets from the chunk boundaries - # we need to further update _chunk_scan_fwd (not yet done). + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. out, out_x = _chunk_scan_fwd( CB, x, @@ -134,15 +138,18 @@ def _mamba_chunk_scan_combined_fwd(x, states, D=D, z=z, - seq_idx=(None if cu_seqlens is not None and initial_states is not None - else seq_idx)) + seq_idx=seq_idx, + initial_states=initial_states, + ) if cu_seqlens is None: return out, out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0), - cu_seqlens, states.squeeze(0)) + cu_seqlens, states.squeeze(0), + initial_states=initial_states, + ) return out, out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 5b44ce07a4b85..c4e6cd2f961f4 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -79,12 +79,17 @@ def _state_passing_fwd_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + # - states will be the past state of the sequence that continues on the current check if not HAS_INITSTATES: states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) else: - initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim + initstates_ptr += offs_m * stride_initstates_dim + initstates_ptrs = initstates_ptr + # - for cont batches, for the first chunk mean it will be the first batch's + # init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk seq_idx = 0 @@ -94,25 +99,24 @@ def _state_passing_fwd_kernel( dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale = tl.exp(dA_cs) if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_new below. seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) if HAS_INITSTATES: if IS_CONT_BATCHED and seq_idx != seq_idx_new: - # need to load the initial state for this new sequence - # - override the scanned state - initstates_ptrs += seq_idx_new * stride_initstates_batch + # this means in the current chunk the rightmost flushed seq + # has changed. + # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + # - update state with seq_idx_new's init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - - # in the previous scan iteration, the wrong state was - # written to the output buffer - # - so we also override it - tl.store(out_ptrs - stride_out_chunk, - states, - mask=offs_m < dim) else: scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)