From 60dadf2e0ee730ac337035d5533de10bc26e4847 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 2 Jun 2024 23:26:32 -0700 Subject: [PATCH] Mamba-2 code release --- README.md | 58 +- .../benchmark_generation_mamba_simple.py | 2 +- mamba_ssm/__init__.py | 3 +- mamba_ssm/distributed/__init__.py | 0 mamba_ssm/distributed/distributed_utils.py | 144 ++ mamba_ssm/distributed/tensor_parallel.py | 296 +++ mamba_ssm/models/config_mamba.py | 3 + mamba_ssm/models/mixer_seq_simple.py | 65 +- mamba_ssm/modules/block.py | 91 + mamba_ssm/modules/mamba2.py | 358 ++++ mamba_ssm/modules/mamba2_simple.py | 199 ++ mamba_ssm/modules/mamba_simple.py | 61 +- mamba_ssm/modules/mha.py | 289 +++ mamba_ssm/modules/mlp.py | 34 + .../triton/{layernorm.py => layer_norm.py} | 525 ++++- mamba_ssm/ops/triton/layernorm_gated.py | 437 ++++ mamba_ssm/ops/triton/ssd_bmm.py | 262 +++ mamba_ssm/ops/triton/ssd_chunk_scan.py | 1825 +++++++++++++++++ mamba_ssm/ops/triton/ssd_chunk_state.py | 866 ++++++++ mamba_ssm/ops/triton/ssd_combined.py | 959 +++++++++ mamba_ssm/ops/triton/ssd_state_passing.py | 348 ++++ 21 files changed, 6707 insertions(+), 118 deletions(-) create mode 100644 mamba_ssm/distributed/__init__.py create mode 100644 mamba_ssm/distributed/distributed_utils.py create mode 100644 mamba_ssm/distributed/tensor_parallel.py create mode 100644 mamba_ssm/modules/block.py create mode 100644 mamba_ssm/modules/mamba2.py create mode 100644 mamba_ssm/modules/mamba2_simple.py create mode 100644 mamba_ssm/modules/mha.py create mode 100644 mamba_ssm/modules/mlp.py rename mamba_ssm/ops/triton/{layernorm.py => layer_norm.py} (54%) create mode 100644 mamba_ssm/ops/triton/layernorm_gated.py create mode 100644 mamba_ssm/ops/triton/ssd_bmm.py create mode 100644 mamba_ssm/ops/triton/ssd_chunk_scan.py create mode 100644 mamba_ssm/ops/triton/ssd_chunk_state.py create mode 100644 mamba_ssm/ops/triton/ssd_combined.py create mode 100644 mamba_ssm/ops/triton/ssd_state_passing.py diff --git a/README.md b/README.md index 47a26edc..6ca031cd 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,9 @@ > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ > Albert Gu*, Tri Dao*\ > Paper: https://arxiv.org/abs/2312.00752 +> **Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\ +> Tri Dao*, Albert Gu*\ +> Paper: https://arxiv.org/abs/2405.21060 ## About @@ -43,7 +46,7 @@ The main module of this repository is the Mamba architecture block wrapping the Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). Usage: -``` +``` python import torch from mamba_ssm import Mamba @@ -60,6 +63,24 @@ y = model(x) assert y.shape == x.shape ``` +The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py). + +A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py) + +The usage is similar to Mamba(-1): +``` python +from mamba_ssm import Mamba2 +model = Mamba( + # This module uses roughly 3 * expand * d_model^2 parameters + d_model=dim, # Model dimension d_model + d_state=64, # SSM state expansion factor, typically 64 or 128 + d_conv=4, # Local convolution width + expand=2, # Block expansion factor +).to("cuda") +y = model(x) +assert y.shape == x.shape +``` + ### Mamba Language Model Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. @@ -70,12 +91,12 @@ This is an example of how to integrate Mamba into an end-to-end neural network. This example is used in the generation scripts below. - ## Pretrained Models Pretrained models are uploaded to [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, -`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj` +`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`, +`mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj` (trained on 600B tokens on the SlimPajama dataset). @@ -106,17 +127,24 @@ library. 1. Install `lm-evaluation-harness` by `pip install lm-eval==0.4.2`. 2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): -``` +``` sh lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 ``` To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts: -``` +``` sh lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256 ``` +To run evaluations on Mamba-2 models, simply replace the model names: +``` sh +lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 +lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 +lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 +``` + Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process. ## Inference @@ -132,16 +160,21 @@ Other configurable options include the top-p (nucleus sampling) probability, and To test generation latency (e.g. batch size = 1) with different sampling strategies: -``` +``` sh python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2 ``` To test generation throughput with random prompts (e.g. large batch size): +``` sh +python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64 +python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64 ``` -python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 -python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128 + +With Mamba-2, you just need to change the model name: +``` sh +python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 ``` @@ -164,7 +197,7 @@ that is specific to the training framework. ## Citation -If you use this codebase, or otherwise found our work valuable, please cite Mamba: +If you use this codebase, or otherwise find our work valuable, please cite Mamba: ``` @article{mamba, title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, @@ -172,4 +205,11 @@ If you use this codebase, or otherwise found our work valuable, please cite Mamb journal={arXiv preprint arXiv:2312.00752}, year={2023} } +@inproceedings{mamba2, + title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality}, + author={Dao, Tri and Gu, Albert}, + booktitle={International Conference on Machine Learning (ICML)}, + year={2024} +} + ``` diff --git a/benchmarks/benchmark_generation_mamba_simple.py b/benchmarks/benchmark_generation_mamba_simple.py index b7607787..f3513b24 100644 --- a/benchmarks/benchmark_generation_mamba_simple.py +++ b/benchmarks/benchmark_generation_mamba_simple.py @@ -32,7 +32,7 @@ dtype = torch.float16 print(f"Loading model {args.model_name}") -is_mamba = args.model_name.startswith("state-spaces/mamba-") +is_mamba = args.model_name.startswith("state-spaces/mamba") or args.model_name.startswith("state-spaces/transformerpp") if is_mamba: tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index c2a23000..64f4c0c4 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -1,5 +1,6 @@ -__version__ = "1.2.2" +__version__ = "2.0.0" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba +from mamba_ssm.modules.mamba2 import Mamba2 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel diff --git a/mamba_ssm/distributed/__init__.py b/mamba_ssm/distributed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mamba_ssm/distributed/distributed_utils.py b/mamba_ssm/distributed/distributed_utils.py new file mode 100644 index 00000000..74c55279 --- /dev/null +++ b/mamba_ssm/distributed/distributed_utils.py @@ -0,0 +1,144 @@ +from typing import Optional + +import torch +from torch import Tensor +from torch.distributed import ProcessGroup + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 4 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base +if "reduce_scatter_tensor" not in dir(torch.distributed): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + + +# Raw operation, does not support autograd, but does support async +def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + output = torch.empty( + world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device + ) + handle = torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) + return output, handle + + +# Raw operation, does not support autograd, but does support async +def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + assert input_.shape[0] % world_size == 0 + output = torch.empty( + input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device + ) + handle = torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) + return output, handle + + +# Raw operation, does not support autograd, but does support async +def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + input_ = input_.contiguous() + handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) + return input_, handle + + +class AllGatherFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_gather_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +all_gather = AllGatherFunc.apply + + +class ReduceScatterFunc(torch.autograd.Function): + """Reduce scatter the input from the sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = reduce_scatter_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = all_gather_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +reduce_scatter = ReduceScatterFunc.apply + + +class AllReduceFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_reduce_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + return grad_output, None + + +# Supports autograd, but does not support async +all_reduce = AllReduceFunc.apply + + +def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): + # We want to iterate over parameters with _shared_params=True in the same order, + # as different ranks might have different number of parameters (e.g., only rank 0 has bias). + pamams_shared = { + name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False) + } + for _, p in sorted(pamams_shared.items()): + with torch.no_grad(): + # Broadcast needs src to be global rank, not group rank + torch.distributed.broadcast( + p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group + ) + + +# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256 +def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup): + # We want to iterate over parameters with _sequence_parallel=True in the same order, + # as different ranks might have different number of parameters (e.g., only rank 0 has bias). + params_seqparallel = { + name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False) + } + grads = [p.grad for _, p in sorted(params_seqparallel.items())] + if grads: + with torch.no_grad(): + coalesced = torch._utils._flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=process_group) + for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + +def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int: + """Get the dim for the local rank derived from splitting dim on world_size processes. + + The split may not be even across the world_size processes. + """ + multiple = dim // multiple_of + div = multiple // world_size + mod = multiple % world_size + local_multiple = div + int(local_rank < mod) + return local_multiple * multiple_of diff --git a/mamba_ssm/distributed/tensor_parallel.py b/mamba_ssm/distributed/tensor_parallel.py new file mode 100644 index 00000000..cc55e793 --- /dev/null +++ b/mamba_ssm/distributed/tensor_parallel.py @@ -0,0 +1,296 @@ +# Copyright (c) 2024, Tri Dao. +# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.distributed import ProcessGroup + +from einops import rearrange + +from src.distributed.distributed_utils import ( + all_gather_raw, + all_reduce, + all_reduce_raw, + reduce_scatter, + reduce_scatter_raw, +) + + +class ParallelLinearFunc(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul. + """ + ctx.compute_weight_gradient = weight.requires_grad + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + output = F.linear(total_x, weight, bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + else: + (weight,) = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + if ctx.needs_input_grad[0]: + grad_input = F.linear(grad_output, weight.t()) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight = torch.einsum( + "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1]) + ) + else: + grad_weight = None + grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None + + +def parallel_linear_func( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, +): + return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel) + + +class ColumnParallelLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + multiple_of=1, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + if out_features % multiple_of: + raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") + multiple = out_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) + super().__init__( + in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. + # If not, then the input is already gathered. + return parallel_linear_func( + x, + self.weight, + self.bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + + +class RowParallelLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + multiple_of=1, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + rank = torch.distributed.get_rank(process_group) + if in_features % multiple_of: + raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") + multiple = in_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) + # Only rank 0 will have bias + super().__init__( + local_multiple * multiple_of, + out_features, + bias=bias and rank == 0, + device=device, + dtype=dtype, + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + """ + We're doing Tensor Parallel with sequence parallelism: we do the matmul and then + a reduce_scatter of the result. + """ + out = parallel_linear_func(x, self.weight, self.bias) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) + + +class VocabParallelEmbedding(nn.Embedding): + def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs): + self.process_group = process_group + if process_group is not None: + world_size = torch.distributed.get_world_size(process_group) + if num_embeddings % world_size != 0: + raise ValueError( + f"num_embeddings ({num_embeddings}) must be divisible by " + f"world_size ({world_size})" + ) + if world_size > 1 and padding_idx is not None: + raise RuntimeError("ParallelEmbedding does not support padding_idx") + else: + world_size = 1 + super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs) + + def forward(self, input: Tensor) -> Tensor: + if self.process_group is None: + return super().forward(input) + else: + rank = torch.distributed.get_rank(self.process_group) + vocab_size = self.num_embeddings + vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size + # Create a mask of valid vocab ids (1 means it needs to be masked). + input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index) + input = input - vocab_start_index + input[input_ids_mask] = 0 + embeddings = super().forward(input) + embeddings[input_ids_mask] = 0.0 + return embeddings + + +class ColumnParallelEmbedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs): + self.process_group = process_group + if process_group is not None: + world_size = torch.distributed.get_world_size(process_group) + if embedding_dim % world_size != 0: + raise ValueError( + f"embedding_dim ({embedding_dim}) must be divisible by " + f"world_size ({world_size})" + ) + else: + world_size = 1 + super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs) + + +class ParallelEmbeddings(nn.Module): + def __init__( + self, + embed_dim, + vocab_size, + max_position_embeddings, + process_group, + padding_idx=None, + sequence_parallel=True, + device=None, + dtype=None, + ): + """ + If max_position_embeddings <= 0, there's no position embeddings + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.word_embeddings = VocabParallelEmbedding( + vocab_size, + embed_dim, + padding_idx=padding_idx, + process_group=process_group, + **factory_kwargs, + ) + self.max_position_embeddings = max_position_embeddings + if self.max_position_embeddings > 0: + self.position_embeddings = ColumnParallelEmbedding( + max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs + ) + + def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + """ + batch_size, seqlen = input_ids.shape + world_size = torch.distributed.get_world_size(self.process_group) + embeddings = self.word_embeddings(input_ids) + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + if world_size <= 1: + embeddings = embeddings + position_embeddings + else: + partition_dim = self.position_embeddings.embedding_dim + rank = torch.distributed.get_rank(self.process_group) + embeddings[ + ..., rank * partition_dim : (rank + 1) * partition_dim + ] += position_embeddings + if combine_batch_seqlen_dim: + embeddings = rearrange(embeddings, "b s d -> (b s) d") + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group) diff --git a/mamba_ssm/models/config_mamba.py b/mamba_ssm/models/config_mamba.py index 2aa1e5a6..646c9e1e 100644 --- a/mamba_ssm/models/config_mamba.py +++ b/mamba_ssm/models/config_mamba.py @@ -5,9 +5,12 @@ class MambaConfig: d_model: int = 2560 + d_intermediate: int = 0 n_layer: int = 64 vocab_size: int = 50277 ssm_cfg: dict = field(default_factory=dict) + attn_layer_idx: list = field(default_factory=list) + attn_cfg: dict = field(default_factory=dict) rms_norm: bool = True residual_in_fp32: bool = True fused_add_norm: bool = True diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index cd224738..4be57e08 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -4,6 +4,7 @@ from functools import partial import json import os +import copy from collections import namedtuple @@ -11,19 +12,26 @@ import torch.nn as nn from mamba_ssm.models.config_mamba import MambaConfig -from mamba_ssm.modules.mamba_simple import Mamba, Block +from mamba_ssm.modules.mamba_simple import Mamba +from mamba_ssm.modules.mamba2 import Mamba2 +from mamba_ssm.modules.mha import MHA +from mamba_ssm.modules.mlp import GatedMLP +from mamba_ssm.modules.block import Block from mamba_ssm.utils.generation import GenerationMixin from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn + from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None def create_block( d_model, + d_intermediate, ssm_cfg=None, + attn_layer_idx=None, + attn_cfg=None, norm_epsilon=1e-5, rms_norm=False, residual_in_fp32=False, @@ -34,14 +42,38 @@ def create_block( ): if ssm_cfg is None: ssm_cfg = {} + if attn_layer_idx is None: + attn_layer_idx = [] + if attn_cfg is None: + attn_cfg = {} factory_kwargs = {"device": device, "dtype": dtype} - mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + if layer_idx not in attn_layer_idx: + # Create a copy of the config to modify + ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {} + ssm_layer = ssm_cfg.pop("layer", "Mamba1") + if ssm_layer not in ["Mamba1", "Mamba2"]: + raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2") + mixer_cls = partial( + Mamba2 if ssm_layer == "Mamba2" else Mamba, + layer_idx=layer_idx, + **ssm_cfg, + **factory_kwargs + ) + else: + mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs) norm_cls = partial( nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs ) + if d_intermediate == 0: + mlp_cls = nn.Identity + else: + mlp_cls = partial( + GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs + ) block = Block( d_model, mixer_cls, + mlp_cls, norm_cls=norm_cls, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, @@ -88,8 +120,11 @@ def __init__( self, d_model: int, n_layer: int, + d_intermediate: int, vocab_size: int, ssm_cfg=None, + attn_layer_idx=None, + attn_cfg=None, norm_epsilon: float = 1e-5, rms_norm: bool = False, initializer_cfg=None, @@ -118,7 +153,10 @@ def __init__( [ create_block( d_model, + d_intermediate=d_intermediate, ssm_cfg=ssm_cfg, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, norm_epsilon=norm_epsilon, rms_norm=rms_norm, residual_in_fp32=residual_in_fp32, @@ -139,6 +177,7 @@ def __init__( _init_weights, n_layer=n_layer, **(initializer_cfg if initializer_cfg is not None else {}), + n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP ) ) @@ -148,7 +187,7 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) for i, layer in enumerate(self.layers) } - def forward(self, input_ids, inference_params=None): + def forward(self, input_ids, inference_params=None, **mixer_kwargs): hidden_states = self.embedding(input_ids) residual = None for layer in self.layers: @@ -160,8 +199,7 @@ def forward(self, input_ids, inference_params=None): hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) else: # Set prenorm=False here since we don't need the residual - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn - hidden_states = fused_add_norm_fn( + hidden_states = layer_norm_fn( hidden_states, self.norm_f.weight, self.norm_f.bias, @@ -169,6 +207,7 @@ def forward(self, input_ids, inference_params=None): residual=residual, prenorm=False, residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm_f, RMSNorm) ) return hidden_states @@ -185,8 +224,11 @@ def __init__( self.config = config d_model = config.d_model n_layer = config.n_layer + d_intermediate = config.d_intermediate vocab_size = config.vocab_size ssm_cfg = config.ssm_cfg + attn_layer_idx = config.attn_layer_idx + attn_cfg = config.attn_cfg rms_norm = config.rms_norm residual_in_fp32 = config.residual_in_fp32 fused_add_norm = config.fused_add_norm @@ -199,8 +241,11 @@ def __init__( self.backbone = MixerModel( d_model=d_model, n_layer=n_layer, + d_intermediate=d_intermediate, vocab_size=vocab_size, ssm_cfg=ssm_cfg, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, rms_norm=rms_norm, initializer_cfg=initializer_cfg, fused_add_norm=fused_add_norm, @@ -222,16 +267,16 @@ def __init__( def tie_weights(self): if self.config.tie_embeddings: self.lm_head.weight = self.backbone.embedding.weight - + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs): """ "position_ids" is just to be compatible with Transformer generation. We don't use it. num_last_tokens: if > 0, only return the logits for the last n tokens """ - hidden_states = self.backbone(input_ids, inference_params=inference_params) + hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs) if num_last_tokens > 0: hidden_states = hidden_states[:, -num_last_tokens:] lm_logits = self.lm_head(hidden_states) @@ -261,4 +306,4 @@ def save_pretrained(self, save_directory): # Save the configuration of the model config_path = os.path.join(save_directory, 'config.json') with open(config_path, 'w') as f: - json.dump(self.config.__dict__, f) + json.dump(self.config.__dict__, f, indent=4) diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py new file mode 100644 index 00000000..b0ed44e1 --- /dev/null +++ b/mamba_ssm/modules/block.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +from typing import Optional + +import torch +from torch import nn, Tensor + +from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn + + +class Block(nn.Module): + def __init__( + self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.norm = norm_cls(dim) + self.mixer = mixer_cls(dim) + if mlp_cls is not nn.Identity: + self.norm2 = norm_cls(dim) + self.mlp = mlp_cls(dim) + else: + self.mlp = None + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Mixer(LN(residual)) + """ + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + is_rms_norm=isinstance(self.norm, RMSNorm) + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs) + + if self.mlp is not None: + if not self.fused_add_norm: + residual = hidden_states + residual + residual = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm2.weight, + self.norm2.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm2.eps, + is_rms_norm=isinstance(self.norm2, RMSNorm) + ) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py new file mode 100644 index 00000000..e60f987d --- /dev/null +++ b/mamba_ssm/modules/mamba2.py @@ -0,0 +1,358 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None, None + +try: + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + +from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear +from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter + +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined + + +class Mamba2(nn.Module): + def __init__( + self, + d_model, + d_state=128, + d_conv=4, + conv_init=None, + expand=2, + headdim=64, + d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP + ngroups=1, + A_init_range=(1, 16), + D_has_hdim=False, + rmsnorm=True, + norm_before_gate=False, + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + dt_limit=(0.0, float("inf")), + bias=False, + conv_bias=True, + # Fused kernel and sharding options + chunk_size=256, + use_mem_eff_path=True, + layer_idx=None, # Absorb kwarg for general module + process_group=None, + sequence_parallel=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.conv_init = conv_init + self.expand = expand + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.world_size = 1 if process_group is None else process_group.size() + self.local_rank = 0 if process_group is None else process_group.rank() + self.d_inner = (self.expand * self.d_model) // self.world_size + assert self.d_inner * self.world_size == self.expand * self.d_model + self.headdim = headdim + self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size + assert ngroups % self.world_size == 0 + self.ngroups = ngroups // self.world_size + assert self.d_ssm % self.headdim == 0 + self.nheads = self.d_ssm // self.headdim + self.D_has_hdim = D_has_hdim + self.rmsnorm = rmsnorm + self.norm_before_gate = norm_before_gate + self.dt_limit = dt_limit + self.activation = "silu" + self.chunk_size = chunk_size + self.use_mem_eff_path = use_mem_eff_path + self.layer_idx = layer_idx + + # Order: [z, x, B, C, dt] + d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + if self.process_group is None: + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) + else: + self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias, + process_group=self.process_group, sequence_parallel=self.sequence_parallel, + **factory_kwargs) + + conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + if self.conv_init is not None: + nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) + + self.act = nn.SiLU() + + # Initialize log dt bias + dt = torch.exp( + torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) + A_log = torch.log(A).to(dtype=dtype) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) + self.D._no_weight_decay = True + + if self.rmsnorm: + assert RMSNormGated is not None + self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate, + group_size=self.d_ssm // ngroups, **factory_kwargs) + + if self.process_group is None: + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + else: + self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias, + process_group=self.process_group, sequence_parallel=self.sequence_parallel, + **factory_kwargs) + + def forward(self, u, seqlen=None, seq_idx=None, inference_params=None): + """ + u: (batch, seqlen, hidden_dim) if seqlen=None. + If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we + split u during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + Returns: same shape as u + """ + seqlen_og = seqlen + if seqlen is None: + batch, seqlen, dim = u.shape + else: + batch_seqlen, dim = u.shape + batch = batch_seqlen // seqlen + + conv_state, ssm_state = None, None + if inference_params is not None: + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(u, conv_state, ssm_state) + return out + + zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj) + if seqlen_og is not None: + zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) + A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state) + dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) + if self.use_mem_eff_path and inference_params is None: + out = mamba_split_conv1d_scan_combined( + zxbcdt, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.dt_bias, + A, + D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, + chunk_size=self.chunk_size, + seq_idx=seq_idx, + activation=self.activation, + rmsnorm_weight=self.norm.weight if self.rmsnorm else None, + rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=None if self.D_has_hdim else self.headdim, + ngroups=self.ngroups, + norm_before_gate=self.norm_before_gate, + **dt_limit_kwargs, + ) + if seqlen_og is not None: + out = rearrange(out, "b l d -> (b l) d") + if self.process_group is not None: + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + out = reduce_fn(out, self.process_group) + else: + d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2 + z0, x0, z, xBC, dt = torch.split( + zxbcdt, + [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads], + dim=-1 + ) + if conv_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC, "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + assert self.activation in ["silu", "swish"] + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + xBC = self.act( + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2) + ) # (B, L, self.d_ssm + 2 * ngroups * d_state) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) + y = mamba_chunk_scan_combined( + rearrange(x, "b l (h p) -> b l h p", p=self.headdim), + dt, + A, + rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), + rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), + chunk_size=self.chunk_size, + D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, + z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, + dt_bias=self.dt_bias, + dt_softplus=True, + seq_idx=seq_idx, + **dt_limit_kwargs, + return_final_states=ssm_state is not None, + ) + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + y = rearrange(y, "b l h p -> b l (h p)") + if self.rmsnorm: + y = self.norm(y, z) + if d_mlp > 0: + y = torch.cat([F.silu(z0) * x0, y], dim=-1) + if seqlen_og is not None: + y = rearrange(y, "b l d -> (b l) d") + out = self.out_proj(y) + return out + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2 + z0, x0, z, xBC, dt = torch.split( + zxbcdt, + [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads], + dim=-1 + ) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(dtype=dtype) + else: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) + A = -torch.exp(self.A_log.float()) # (nheads,) + + # SSM step + if selective_state_update is None: + assert self.ngroups == 1, "Only support ngroups=1 for this inference code path" + # Discretize A and B + dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) + dA = torch.exp(dt * A) # (batch, nheads) + x = rearrange(x, "b (h p) -> b h p", p=self.headdim) + dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) + ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) + y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) + y = y + rearrange(self.D.to(dtype), "h -> h 1") * x + y = rearrange(y, "b h p -> b (h p)") + if not self.rmsnorm: + y = y * self.act(z) # (B D) + else: + A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) + dt = repeat(dt, "b h -> b h p", p=self.headdim) + dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) + D = repeat(self.D, "h -> h p", p=self.headdim) + B = rearrange(B, "b (g n) -> b g n", g=self.ngroups) + C = rearrange(C, "b (g n) -> b g n", g=self.ngroups) + x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + y = selective_state_update( + ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None, + dt_bias=dt_bias, dt_softplus=True + ) + y = rearrange(y, "b h p -> b (h p)") + if self.rmsnorm: + y = self.norm(y, z) + if d_mlp > 0: + y = torch.cat([F.silu(z0) * x0, y], dim=-1) + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype + ) + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + batch_shape = (batch_size,) + conv_state = torch.zeros( + batch_size, + self.conv1d.weight.shape[0], + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.nheads, + self.headdim, + self.d_state, + device=self.in_proj.weight.device, + dtype=self.in_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py new file mode 100644 index 00000000..026c674b --- /dev/null +++ b/mamba_ssm/modules/mamba2_simple.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +try: + from causal_conv1d import causal_conv1d_fn +except ImportError: + causal_conv1d_fn = None + +try: + from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm +except ImportError: + RMSNormGated, LayerNorm = None, None + +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined + + +class Mamba2Simple(nn.Module): + def __init__( + self, + d_model, + d_state=64, + d_conv=4, + conv_init=None, + expand=2, + headdim=128, + ngroups=1, + A_init_range=(1, 16), + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + dt_limit=(0.0, float("inf")), + learnable_init_states=False, + activation="swish", + bias=False, + conv_bias=True, + # Fused kernel and sharding options + chunk_size=256, + use_mem_eff_path=True, + layer_idx=None, # Absorb kwarg for general module + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.conv_init = conv_init + self.expand = expand + self.d_inner = self.expand * self.d_model + self.headdim = headdim + self.ngroups = ngroups + assert self.d_inner % self.headdim == 0 + self.nheads = self.d_inner // self.headdim + self.dt_limit = dt_limit + self.learnable_init_states = learnable_init_states + self.activation = activation + self.chunk_size = chunk_size + self.use_mem_eff_path = use_mem_eff_path + self.layer_idx = layer_idx + + # Order: [z, x, B, C, dt] + d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) + + conv_dim = self.d_inner + 2 * self.ngroups * self.d_state + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + if self.conv_init is not None: + nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) + # self.conv1d.weight._no_weight_decay = True + + if self.learnable_init_states: + self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)) + self.init_states._no_weight_decay = True + + self.act = nn.SiLU() + + # Initialize log dt bias + dt = torch.exp( + torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + # A parameter + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) + A_log = torch.log(A).to(dtype=dtype) + self.A_log = nn.Parameter(A_log) + # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.nheads, device=device)) + self.D._no_weight_decay = True + + # Extra normalization layer right before output projection + assert RMSNormGated is not None + self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs) + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + def forward(self, u, seq_idx=None): + """ + u: (B, L, D) + Returns: same shape as u + """ + batch, seqlen, dim = u.shape + + zxbcdt = self.in_proj(u) # (B, L, d_in_proj) + A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state) + initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None + dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) + + if self.use_mem_eff_path: + # Fully fused path + out = mamba_split_conv1d_scan_combined( + zxbcdt, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=seq_idx, + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.eps, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.headdim, + ngroups=self.ngroups, + norm_before_gate=False, + initial_states=initial_states, + **dt_limit_kwargs, + ) + else: + z, xBC, dt = torch.split( + zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1 + ) + dt = F.softplus(dt + self.dt_bias) # (B, L, nheads) + assert self.activation in ["silu", "swish"] + + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + xBC = self.act( + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2) + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + xBC = causal_conv1d_fn( + x=xBC.transpose(1, 2), + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + + # Split into 3 main branches: X, B, C + # These correspond to V, K, Q respectively in the SSM/attention duality + x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) + y = mamba_chunk_scan_combined( + rearrange(x, "b l (h p) -> b l h p", p=self.headdim), + dt, + A, + rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), + rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=seq_idx, + initial_states=initial_states, + **dt_limit_kwargs, + ) + y = rearrange(y, "b l h p -> b l (h p)") + + # Multiply "gate" branch and apply extra normalization layer + y = self.norm(y, z) + out = self.out_proj(y) + return out diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 91cb9798..4c8a3882 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -23,7 +23,7 @@ selective_state_update = None try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn + from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None @@ -292,62 +292,3 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states conv_state.zero_() ssm_state.zero_() return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/mamba_ssm/modules/mha.py b/mamba_ssm/modules/mha.py new file mode 100644 index 00000000..ae3b099e --- /dev/null +++ b/mamba_ssm/modules/mha.py @@ -0,0 +1,289 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +try: + from flash_attn import flash_attn_with_kvcache +except ImportError: + flash_attn_with_kvcache = None + +try: + from flash_attn.layers.rotary import RotaryEmbedding +except ImportError: + RotaryEmbedding = None + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None, None + + +def _update_kv_cache(kv, inference_params, layer_idx): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + # Pre-allocate memory for key-values for inference. + num_heads, head_dim = kv.shape[-2:] + assert layer_idx in inference_params.key_value_memory_dict + kv_cache, _ = inference_params.key_value_memory_dict[layer_idx] + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.seqlen_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= kv_cache.shape[0] + assert sequence_end <= kv_cache.shape[1] + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + return kv_cache[batch_start:batch_end, :sequence_end, ...] + + +class MHA(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + embed_dim, + num_heads, + num_heads_kv=None, + head_dim=None, # If None, use embed_dim // num_heads + mlp_dim=0, + qkv_proj_bias=True, + out_proj_bias=True, + softmax_scale=None, + causal=False, + layer_idx=None, + d_conv=0, + rotary_emb_dim=0, + rotary_emb_base=10000.0, + rotary_emb_interleaved=False, + device=None, + dtype=None, + ) -> None: + """ + num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.layer_idx = layer_idx + self.d_conv = d_conv + self.rotary_emb_dim = rotary_emb_dim + self.softmax_scale = softmax_scale + self.causal = causal + + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert ( + self.num_heads % self.num_heads_kv == 0 + ), "num_heads must be divisible by num_heads_kv" + if head_dim is None: + assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads + self.mlp_dim = math.ceil(mlp_dim / 256) * 256 + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + out_dim = self.head_dim * self.num_heads + + if self.rotary_emb_dim > 0: + assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, + base=rotary_emb_base, + interleaved=rotary_emb_interleaved, + device=device, + ) + + self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs) + if self.d_conv > 0: + self.conv1d = nn.Conv1d( + qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim, + **factory_kwargs + ) + self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + if self.d_conv > 0: + conv_state = torch.zeros( + batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype + ) + else: + conv_state = None + kv_cache = torch.empty( + batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device, + ) + return kv_cache, conv_state + + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): + """ + Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. + q: (batch_size, seqlen_q, nheads, head_dim) + kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) + """ + assert inference_params is not None and inference_params.seqlen_offset > 0 + if self.rotary_emb_dim > 0: + self.rotary_emb._update_cos_sin_cache( + inference_params.max_seqlen, device=q.device, dtype=q.dtype + ) + rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached + else: + rotary_cos, rotary_sin = None, None + batch = q.shape[0] + kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx] + kv_cache = kv_cache[:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + assert flash_attn_with_kvcache is not None, "flash_attn must be installed" + context = flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + softmax_scale=self.softmax_scale, + causal=self.causal, + rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, + ) + return context + + def _update_kvcache_attention(self, q, kv, inference_params): + """Write kv to inference_params, then do attention""" + if ( + inference_params.seqlen_offset == 0 + or flash_attn_with_kvcache is None + ): + # TODO: this only uses seqlen_offset and not lengths_per_sample. + kv = self._update_kv_cache(kv, inference_params) + k, v = kv.unbind(dim=-3) + return F.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale + ).transpose(1, 2) + else: + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + return flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + cache_seqlens=cache_seqlens, + softmax_scale=self.softmax_scale, + causal=self.causal, + ) + + def forward(self, x, inference_params=None): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if + cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total + is the is the sum of the sequence lengths in the batch. + inference_params: for generation. Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + """ + if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + x.shape[0], inference_params.max_seqlen, dtype=x.dtype + ) + seqlen_offset = ( + 0 + if inference_params is None + else ( + inference_params.lengths_per_sample + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None + qkv = self.in_proj(x) + if self.mlp_dim > 0: + qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1) + x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1) + x_mlp = x_mlp_up * F.silu(x_mlp_gate) + if self.d_conv > 0: + # The inference code for conv1d is pretty messy, should clean it up + if (inference_params is None or inference_params.seqlen_offset == 0): + if causal_conv1d_fn is None: + qkv = rearrange( + self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d" + ).contiguous() + else: + qkv = causal_conv1d_fn( + qkv.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias + ).transpose(1, 2) + if inference_params is not None: + _, conv_state = inference_params.key_value_memory_dict[self.layer_idx] + # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + qkv_t = rearrange(qkv, "b l d -> b d l") + conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W) + else: + _, conv_state = inference_params.key_value_memory_dict[self.layer_idx] + assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now" + qkv = qkv.squeeze(1) + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = qkv + qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + qkv = qkv + self.conv1d.bias + else: + qkv = causal_conv1d_update( + qkv, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias + ) + qkv = qkv.unsqueeze(1) + q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1) + q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) + kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + ): + if self.rotary_emb_dim > 0: + q, kv = self.rotary_emb( + q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + k, v = kv.unbind(dim=-3) + context = F.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale + ).transpose(1, 2) + else: + context = self._update_kvcache_attention(q, kv, inference_params) + else: + context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) + context = rearrange(context, "... h d -> ... (h d)") + if self.mlp_dim > 0: + context = torch.cat([context, x_mlp], dim=-1) + out = self.out_proj(context) + return out diff --git a/mamba_ssm/modules/mlp.py b/mamba_ssm/modules/mlp.py new file mode 100644 index 00000000..33bab5c7 --- /dev/null +++ b/mamba_ssm/modules/mlp.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +from torch import nn +from torch.nn import functional as F + + +class GatedMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation=F.silu, + bias=False, + multiple_of=128, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features if out_features is not None else in_features + hidden_features = ( + hidden_features if hidden_features is not None else int(8 * in_features / 3) + ) + hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of + self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs) + + def forward(self, x): + y = self.fc1(x) + y, gate = y.chunk(2, dim=-1) + y = y * self.activation(gate) + y = self.fc2(y) + return y diff --git a/mamba_ssm/ops/triton/layernorm.py b/mamba_ssm/ops/triton/layer_norm.py similarity index 54% rename from mamba_ssm/ops/triton/layernorm.py rename to mamba_ssm/ops/triton/layer_norm.py index ba33ce1e..6fcf50e1 100644 --- a/mamba_ssm/ops/triton/layernorm.py +++ b/mamba_ssm/ops/triton/layer_norm.py @@ -1,5 +1,5 @@ -# Copyright (c) 2023, Tri Dao. -# Implement residual + layer_norm / rms_norm. +# Copyright (c) 2024, Tri Dao. +# Implement dropout + residual + layer_norm / rms_norm. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. @@ -16,36 +16,113 @@ import triton.language as tl -def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): +def layer_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): dtype = x.dtype if upcast: + x = x.float() weight = weight.float() bias = bias.float() if bias is not None else None - if upcast: - x = x.float() residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 if residual is not None: x = (x + residual).to(x.dtype) out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( dtype ) - return out if not prenorm else (out, x) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = F.layer_norm( + x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps + ).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) -def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): +def rms_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): dtype = x.dtype if upcast: + x = x.float() weight = weight.float() bias = bias.float() if bias is not None else None - if upcast: - x = x.float() residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 if residual is not None: x = (x + residual).to(x.dtype) rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) - out = out.to(dtype) - return out if not prenorm else (out, x) + out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( + dtype + ) + return (out, out1) if not prenorm else (out, out1, x) @triton.autotune( @@ -61,6 +138,9 @@ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input @@ -68,20 +148,37 @@ def _layer_norm_fwd_1pass_kernel( W, # pointer to the weights B, # pointer to the biases RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_res_row, stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) @@ -91,9 +188,38 @@ def _layer_norm_fwd_1pass_kernel( RESIDUAL += row * stride_res_row if STORE_RESIDUAL_OUT: RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row # Compute mean and variance cols = tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + x += x1 if HAS_RESIDUAL: residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) x += residual @@ -118,10 +244,29 @@ def _layer_norm_fwd_1pass_kernel( y = x_hat * w + b if HAS_BIAS else x_hat * w # Write output tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) def _layer_norm_fwd( - x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False + x, + weight, + bias, + eps, + residual=None, + x1=None, + weight1=None, + bias1=None, + dropout_p=0.0, + rowscale=None, + out_dtype=None, + residual_dtype=None, + is_rms_norm=False, + return_dropout_mask=False, ): if residual is not None: residual_dtype = residual.dtype @@ -135,22 +280,57 @@ def _layer_norm_fwd( if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) # allocate output y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) assert y.stride(-1) == 1 - if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): - residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + if weight1 is not None: + y1 = torch.empty_like(y) + assert y1.stride(-1) == 1 + else: + y1 = None + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + residual_out = torch.empty( + M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) assert residual_out.stride(-1) == 1 else: residual_out = None mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask = None # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps with torch.cuda.device(x.device.index): _layer_norm_fwd_1pass_kernel[(M,)]( x, @@ -158,23 +338,50 @@ def _layer_norm_fwd( weight, bias, residual, + x1, + weight1, + bias1, + y1, residual_out, + rowscale, + seeds, + dropout_mask, mean, rstd, x.stride(0), y.stride(0), residual.stride(0) if residual is not None else 0, residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, N, eps, + dropout_p, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None, bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, ) - # residual_out is None if residual is None and residual_dtype == input_dtype - return y, mean, rstd, residual_out if residual_out is not None else x + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if dropout_mask is not None and x1 is not None: + dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) + else: + dropout_mask1 = None + return ( + y, + y1, + mean, + rstd, + residual_out if residual_out is not None else x, + seeds, + dropout_mask, + dropout_mask1, + ) @triton.autotune( @@ -186,11 +393,15 @@ def _layer_norm_fwd( triton.Config({}, num_warps=16), triton.Config({}, num_warps=32), ], - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) @triton.jit def _layer_norm_bwd_kernel( @@ -203,7 +414,14 @@ def _layer_norm_bwd_kernel( DW, # pointer to the partial sum of weights gradient DB, # pointer to the partial sum of biases gradient DRESIDUAL, + W1, + DY1, + DX1, + DW1, + DB1, DRESIDUAL_IN, + ROWSCALE, + SEEDS, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row @@ -211,21 +429,30 @@ def _layer_norm_bwd_kernel( stride_dy_row, stride_dx_row, stride_dres_row, + stride_dy1_row, + stride_dx1_row, stride_dres_in_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero + dropout_p, rows_per_program, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_DRESIDUAL: tl.constexpr, STORE_DRESIDUAL: tl.constexpr, HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, ): # Map the program id to the elements of X, DX, and DY it should compute. row_block_id = tl.program_id(0) row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB cols = tl.arange(0, BLOCK_N) mask = cols < N X += row_start * stride_x_row @@ -235,19 +462,31 @@ def _layer_norm_bwd_kernel( DRESIDUAL_IN += row_start * stride_dres_in_row DY += row_start * stride_dy_row DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row if RECOMPUTE_OUTPUT: Y += row_start * stride_y_row w = tl.load(W + cols, mask=mask).to(tl.float32) if RECOMPUTE_OUTPUT and HAS_BIAS: b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) dw = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_BIAS: db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) row_end = min((row_block_id + 1) * rows_per_program, M) for row in range(row_start, row_end): # Load data to SRAM x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) if not IS_RMS_NORM: mean = tl.load(Mean + row) rstd = tl.load(Rstd + row) @@ -261,6 +500,11 @@ def _layer_norm_bwd_kernel( dw += dy * xhat if HAS_BIAS: db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 if not IS_RMS_NORM: c1 = tl.sum(xhat * wdy, axis=0) / N c2 = tl.sum(wdy, axis=0) / N @@ -274,6 +518,21 @@ def _layer_norm_bwd_kernel( # Write dx if STORE_DRESIDUAL: tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale tl.store(DX + cols, dx, mask=mask) X += stride_x_row @@ -285,9 +544,17 @@ def _layer_norm_bwd_kernel( Y += stride_y_row DY += stride_dy_row DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row tl.store(DW + row_block_id * N + cols, dw, mask=mask) if HAS_BIAS: tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) def _layer_norm_bwd( @@ -299,7 +566,14 @@ def _layer_norm_bwd( mean, rstd, dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, has_residual=False, + has_x1=False, is_rms_norm=False, x_dtype=None, recompute_output=False, @@ -316,14 +590,38 @@ def _layer_norm_bwd( if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) # allocate output dx = ( torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) ) - dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + dresidual_in = ( + torch.empty_like(x) + if has_residual + and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() @@ -337,6 +635,8 @@ def _layer_norm_bwd( if bias is not None else None ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None rows_per_program = math.ceil(M / sm_count) grid = (sm_count,) with torch.cuda.device(x.device.index): @@ -350,7 +650,14 @@ def _layer_norm_bwd( _dw, _db, dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, dresidual_in, + rowscale, + seeds, mean, rstd, x.stride(0), @@ -358,23 +665,35 @@ def _layer_norm_bwd( dy.stride(0), dx.stride(0), dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, dresidual_in.stride(0) if dresidual_in is not None else 0, M, N, eps, + dropout_p, rows_per_program, is_rms_norm, BLOCK_N, dresidual is not None, dresidual_in is not None, bias is not None, + dropout_p > 0.0, ) dw = _dw.sum(0).to(weight.dtype) db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype: + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: dresidual_in = dx - return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) class LayerNormFn(torch.autograd.Function): @@ -385,10 +704,16 @@ def forward( weight, bias, residual=None, + x1=None, + weight1=None, + bias1=None, eps=1e-6, + dropout_p=0.0, + rowscale=None, prenorm=False, residual_in_fp32=False, is_rms_norm=False, + return_dropout_mask=False, ): x_shape_og = x.shape # reshape input data into 2D tensor @@ -400,34 +725,91 @@ def forward( residual = residual.reshape(-1, residual.shape[-1]) if residual.stride(-1) != 1: residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() weight = weight.contiguous() if bias is not None: bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() residual_dtype = ( residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) ) - y, mean, rstd, residual_out = _layer_norm_fwd( - x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd ) - ctx.save_for_backward(residual_out, weight, bias, mean, rstd) ctx.x_shape_og = x_shape_og ctx.eps = eps + ctx.dropout_p = dropout_p ctx.is_rms_norm = is_rms_norm ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None ctx.prenorm = prenorm ctx.x_dtype = x.dtype y = y.reshape(x_shape_og) - return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) @staticmethod def backward(ctx, dy, *args): - x, weight, bias, mean, rstd = ctx.saved_tensors + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors dy = dy.reshape(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None if ctx.prenorm: dresidual = args[0] dresidual = dresidual.reshape(-1, dresidual.shape[-1]) @@ -436,7 +818,7 @@ def backward(ctx, dy, *args): assert dresidual.shape == x.shape else: dresidual = None - dx, dw, db, dresidual_in = _layer_norm_bwd( + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( dy, x, weight, @@ -445,7 +827,14 @@ def backward(ctx, dy, *args): mean, rstd, dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, ctx.has_residual, + ctx.has_x1, ctx.is_rms_norm, x_dtype=ctx.x_dtype, ) @@ -454,6 +843,12 @@ def backward(ctx, dy, *args): dw, db, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, None, None, None, @@ -466,23 +861,78 @@ def layer_norm_fn( weight, bias, residual=None, + x1=None, + weight1=None, + bias1=None, eps=1e-6, + dropout_p=0.0, + rowscale=None, prenorm=False, residual_in_fp32=False, is_rms_norm=False, + return_dropout_mask=False, ): - return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + is_rms_norm, + return_dropout_mask, + ) -def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): - return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + True, + return_dropout_mask, + ) class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() @@ -497,6 +947,7 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): self.bias, residual=residual, eps=self.eps, + dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, prenorm=prenorm, residual_in_fp32=residual_in_fp32, ) @@ -536,7 +987,7 @@ def forward( if residual is not None else (torch.float32 if residual_in_fp32 else None) ) - y, mean, rstd, residual_out = _layer_norm_fwd( + y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( x, norm_weight, norm_bias, @@ -580,7 +1031,7 @@ def backward(ctx, dout, *args): assert dresidual.shape == x.shape else: dresidual = None - dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( dy, x, norm_weight, @@ -588,9 +1039,9 @@ def backward(ctx, dout, *args): ctx.eps, mean, rstd, - dresidual, - ctx.has_residual, - ctx.is_rms_norm, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, x_dtype=ctx.x_dtype, recompute_output=True, ) diff --git a/mamba_ssm/ops/triton/layernorm_gated.py b/mamba_ssm/ops/triton/layernorm_gated.py new file mode 100644 index 00000000..de4b2f48 --- /dev/null +++ b/mamba_ssm/ops/triton/layernorm_gated.py @@ -0,0 +1,437 @@ +# Copyright (c) 2024, Tri Dao. +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange + + +def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True): + dtype = x.dtype + N = x.shape[-1] + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) + out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd, + x.stride(0), out.stride(0), z.stride(0) if z is not None else 0, + M, group_size, eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps) + return out, mean, rstd + + + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DZ, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_z_row, + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dz_row, + stride_dw_row, + stride_db_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + group = tl.program_id(1) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + group * N + if HAS_Z: + Z += row_start * stride_z_row + group * N + DZ += row_start * stride_dz_row + group * N + DY += row_start * stride_dy_row + group * N + DX += row_start * stride_dx_row + group * N + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS: + B += group * N + b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) + x_og = x + x = x_og * z * tl.sigmoid(z) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.) + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + y = xhat * w + b if HAS_BIAS else xhat * w + if RECOMPUTE_OUTPUT: + tl.store(Y + cols, y * z * z_sigmoid, mask=mask) + dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dy *= z * z_sigmoid + else: + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + c1 = tl.sum(xhat * wdy, axis=0) / N + if not IS_RMS_NORM: + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + dx = (wdy - xhat * c1) * rstd + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_Z and not NORM_BEFORE_GATE: + z_sigmoid = tl.sigmoid(z) + dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dx *= z * z_sigmoid + # Write dx + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_Z: + Z += stride_z_row + DZ += stride_dz_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask) + + +def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None, + norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = torch.empty_like(x) + if dz is not None: + assert z is not None + assert dz.shape == z.shape + assert dz.stride(-1) == 1 + else: + dz = torch.empty_like(z) if z is not None else None + if recompute_output: + if out is None: + out = torch.empty_like(x) + assert out.shape == x.shape + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs + # would limit the occupancy. + nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups) + _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device) + _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program = math.ceil(M / nrow_groups) + grid = (nrow_groups, ngroups) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None, + dy, dx, _dw, _db, dz, mean, rstd, + x.stride(0), + z.stride(0) if z is not None else 0, + 0 if not recompute_output else out.stride(0), + dy.stride(0), dx.stride(0), + dz.stride(0) if dz is not None else 0, + _dw.stride(0), + _db.stride(0) if _db is not None else 0, + M, group_size, eps, + rows_per_program, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out) + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, + is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + + @staticmethod + def backward(ctx, dy): + x, weight, bias, mean, rstd, z = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size, + ctx.norm_before_gate, ctx.is_rms_norm) + return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None + + +def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm) + + +def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True) + + +class LayerNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps, + norm_before_gate=self.norm_before_gate) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size, + norm_before_gate=self.norm_before_gate) diff --git a/mamba_ssm/ops/triton/ssd_bmm.py b/mamba_ssm/ops/triton/ssd_bmm.py new file mode 100644 index 00000000..48fd4f06 --- /dev/null +++ b/mamba_ssm/ops/triton/ssd_bmm.py @@ -0,0 +1,262 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, b_ptr, out_ptr, seq_idx_ptr, + # Matrix dimensions + seqlen, chunk_size, K, ngroups, + stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk, + stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'K'], +) +@triton.jit +def _bmm_chunk_bwd_kernel( + # Pointers to matrices + a_ptr, dout_ptr, db_ptr, res_ptr, + # Matrix dimensions + seqlen, chunk_size, K, ngroups, + stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, + stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n, + stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k, + stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k, + # Meta-parameters + dot_dtype: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(K, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_cs = tl.arange(0, BLOCK_SIZE_CS) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m) + a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype) + a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype) + acc += tl.dot(dout, a) + dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m + a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_RESIDUAL: + res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head + res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k) + res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32) + acc += res + db = acc.to(db_ptr.dtype.element_ty) + + db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head + db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k) + tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)) + + +def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. + out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, dtype=out_dtype) + dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), + batch, nchunks if not has_groups else nchunks * ngroups) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, b, out, seq_idx, + seqlen, chunk_size, k, ngroups if has_groups else 1, + a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), + b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1), + out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out + + +def _bmm_chunk_bwd(a, dout, residual=None, out=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + Return: + out: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + + If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be + zeroed out before calling this function. + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + nchunks, chunk_size = dout.shape[1], dout.shape[-1] + if a.stride(-1) != 1 and a.stride(-2) != 1: + a = a.contiguous() + if dout.stride(-1) != 1 and dout.stride(-2) != 1: + dout = dout.contiguous() + if residual is not None: + assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k) + if residual.stride(-1) != 1 and residual.stride(1) != 1: + residual = residual.contiguous() + # Allocates output. + if out is not None: + assert out.shape == a.shape + assert out.stride(-1) == 1 or out.stride(1) == 1 + else: + out = torch.empty_like(a) + dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch, + nchunks if not has_groups else nchunks * ngroups) + residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2), + residual.stride(-1)) + if residual is not None else (0, 0, 0, 0)) + with torch.cuda.device(a.device.index): + _bmm_chunk_bwd_kernel[grid]( + a, dout, out, residual, + seqlen, chunk_size, k, ngroups if has_groups else 1, + a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), + dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1), + out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1), + residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3], + dot_dtype, + HAS_RESIDUAL=residual is not None, + ) + return out diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py new file mode 100644 index 00000000..ad3d5f5a --- /dev/null +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -0,0 +1,1825 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +from packaging import version + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or pid_c > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. + cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) + tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_scan_fwd_kernel_wip( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_D_head, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_n = tl.program_id(axis=0) + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE) + + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + # if pid_c == 0: + # if pid_b == 0: + # if pid_h == 0: + # tl.device_print("", prev_states) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # scale_m = tl.exp(dA_cs_m) + # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) + # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # cb *= dt_m + # mask = offs_m[:, None] >= offs_m[None, :] + # cb = tl.where(mask, cb, 0.0) + # cb = cb.to(x_ptr.dtype.element_ty) + # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0) + # acc += tl.dot(cb, x) + # if HAS_D: + # if D_HAS_HDIM: + # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + # else: + # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + # acc += x.to(tl.float32) * D + # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M): + start_m = tl.multiple_of(start_m, BLOCK_SIZE_M) + dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0) + acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) + # cb *= dt_m + # mask = offs_m[:, None] >= offs_m[None, :] + # cb = tl.where(mask, cb, 0.0) + # cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0) + # acc += tl.dot(cb, x) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + acc += x.to(tl.float32) * D + + # if HAS_Z: + # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) + # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + # acc *= z * tl.sigmoid(z) + + tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim)) + + # TODO: this is not correct, and quite a bit slower + if start_m + BLOCK_SIZE_M < chunk_size_limit: + # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32) + B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0) + dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32) + # TODO: seq_idx + scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m + # B *= scale + B = B.to(x_ptr.dtype.element_ty) + tmp = tl.dot(B, x) + prev_states += tmp.to(prev_states.dtype) + + C_ptrs += BLOCK_SIZE_M * stride_C_seqlen + B_ptrs += BLOCK_SIZE_M * stride_B_seqlen + cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_M * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_M * stride_dt_csize + out_ptrs += BLOCK_SIZE_M * stride_out_seqlen + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32}), + triton.Config({'BLOCK_SIZE_M': 64}), + triton.Config({'BLOCK_SIZE_M': 128}), + triton.Config({'BLOCK_SIZE_M': 256}), + ], + key=["chunk_size", "hdim"], +) +@triton.jit +def _chunk_scan_bwd_dz_kernel( + # Pointers to matrices + dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_D_head, + stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim, + stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim, + stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_DDACS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head + if RECOMPUTE_OUTPUT: + outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head + if HAS_DDACS: + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + if HAS_D: + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim) + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim) + dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim) + if RECOMPUTE_OUTPUT: + outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim) + if HAS_D: + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + if RECOMPUTE_OUTPUT: + outz = out * z * z_sigmoid + tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + dout *= z * z_sigmoid + tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + if HAS_D: + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + dD = tl.sum(dout * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + dD = tl.sum(dout * x) + tl.store(dD_ptr, dD) + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out -= x * D + if HAS_DDACS: + ddA_cs = tl.sum(dout * out, axis=1) + tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_scan_bwd_dstates_kernel( + # Pointers to matrices + dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + batch, seqlen, nchunks, nheads_ngroups_ratio, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate, + stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen) + c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale_k = tl.exp(dA_cs_k) + else: + seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) + scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0) + dout = (dout * scale_k).to(dout_ptr.dtype.element_ty) + c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0) + acc += tl.dot(dout, c) + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + c_ptrs += BLOCK_SIZE_K * stride_c_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + out = acc.to(dprev_states_ptr.dtype.element_ty) + + dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate) + tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_dc_kernel( + # Pointers to matrices + dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, + dc_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split + prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + if HAS_DDA_CS: + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + prev_states = prev_states.to(dout_ptrs.dtype.element_ty) + dc = tl.dot(dout, prev_states) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_m) + else: + scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + dc *= scale[:, None] + if HAS_DDA_CS: + ddA_cs = tl.sum(dc * c, axis=1) + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + acc += dc + dout_ptrs += stride_dout_head + prev_states_ptrs += stride_prev_states_head + dA_cumsum_ptrs += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptrs += stride_ddA_cs_head + # if HAS_SEQ_IDX: + # seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + # acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate) + tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr, + dx_ptr, ddt_ptr, # dD_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_D_head, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + # if HAS_D: + # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Idk why limiting K_MAX gives wrong results, is it a Triton bug? + # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + K_MAX = chunk_size_limit + for k in range(0, K_MAX, BLOCK_SIZE_K): + # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + mask = k + offs_k[None, :] >= offs_m[:, None] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + # if HAS_D: + # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) + # dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32) + # dD = tl.sum(x * dout, axis=0) + # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) + + +# Disabling HAS_DDA_CS for now since it's much slower +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim'], +) +# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) +# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32}) +@triton.jit +def _chunk_scan_bwd_dcb_kernel( + # Pointers to matrices + x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + dcb_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + if HAS_DDA_CS: + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + return + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + dcb = tl.dot(dout, x) + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + dcb *= dt_n + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) + dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + if HAS_DDA_CS: + tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") + ddA_cs = dcb * cb + mask = offs_m[:, None] >= offs_n[None, :] + 1 + ddA_cs = tl.where(mask, ddA_cs, 0.0) + ddA_cs = tl.cumsum(ddA_cs, axis=1) + ddA_cs = tl.where(mask, ddA_cs, 0.0) + ddA_cs = tl.sum(ddA_cs, axis=0) + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) + tl.store(ddA_cumsum_ptr, 0.0) + acc += dcb + dout_ptrs += stride_dout_head + x_ptrs += stride_x_head + dt_ptrs += stride_dt_head + dA_cumsum_ptr += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptr += stride_ddA_cs_head + ddA_cumsum_ptrs += stride_ddA_cs_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + mask = offs_m[:, None] >= offs_n[None, :] + acc = tl.where(mask, acc, 0.0) + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + +# Not numerically stable and should not be used. Leaving here for reference. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32}), + triton.Config({'BLOCK_SIZE_M': 64}), + triton.Config({'BLOCK_SIZE_M': 128}), + triton.Config({'BLOCK_SIZE_M': 256}), + ], + key=["chunk_size", "hdim"], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_unstable_kernel( + # Pointers to matrices + dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr, + ddA_cumsum_ptr, dD_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_D_head, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + SUBTRACT_DDTDT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + if HAS_D: + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + if HAS_D: + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + dD = tl.sum(dout * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + dD = tl.sum(dout * x) + tl.store(dD_ptr, dD) + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out -= x * D + ddA_cs = tl.sum(dout * out, axis=1) + if SUBTRACT_DDTDT: + dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddA_cs -= dt * ddt + tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_stable_kernel_old( + # Pointers to matrices + x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, + ddAcs_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + # Doing a matmul loop with cumsum later on will cause Triton to crash + # Instead we do just one big matmul + # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # for k in range(0, hdim, BLOCK_SIZE_K): + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) + # acc += tl.dot(dout, x) + # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim + # x_ptrs += BLOCK_SIZE_K * stride_x_hdim + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + acc = tl.dot(dout, x) + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) + acc *= cb + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + acc *= dt_n + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + mask = offs_m[:, None] >= offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + acc = tl.cumsum(acc, axis=1) + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n + tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) + tl.store(ddAcs_ptr, 0.0) + + # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64) + # offs_k = tl.arange(0, BLOCK_SIZE_K) + # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + # dt_ptrs = dt_ptr + offs_n * stride_dt_csize + # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + + # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m + # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n + # for n in range(0, chunk_size_limit_n, 64): + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0) + # acc = tl.dot(dout, x) + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32) + # acc *= cb + # dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) + # acc *= dt_n + # dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) + # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + # mask = offs_m[:, None] >= offs_n[None, :] + 1 + n + # acc = tl.where(mask, acc, 0.0) + # acc = tl.cumsum(acc, axis=1) + # acc = tl.where(mask, acc, 0.0) + # ddA_cs = tl.sum(acc, axis=0) + # tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n) + # # tl.store(ddAcs_ptr, 0.0) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_stable_kernel( + # Pointers to matrices + x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + tl.store(ddA_cumsum_ptr, 0.0) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower + lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M + # lo, hi = 0, chunk_size + for start_n in range(lo, hi, BLOCK_SIZE_N): + start_n = tl.multiple_of(start_n, BLOCK_SIZE_N) + # Doing a matmul loop with cumsum later on will cause Triton to crash + # Instead we do just one big matmul + # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # for k in range(0, hdim, BLOCK_SIZE_K): + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) + # acc += tl.dot(dout, x) + # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim + # x_ptrs += BLOCK_SIZE_K * stride_x_hdim + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) + acc = tl.dot(dout, x) + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= dt_n + # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j] + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) + acc *= cb + dA_cs_n = tl.load(dA_cumsum_ptr + start_n + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + rowsum_new = rowsum + tl.sum(acc, axis=1) + acc = rowsum[:, None] + tl.cumsum(acc, axis=1) + rowsum = rowsum_new + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1) + x_ptrs += BLOCK_SIZE_N * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_N * stride_dt_csize + cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + # Need to zero out the rest, since we'll be summing the rows together + for start_n in range(hi, chunk_size, BLOCK_SIZE_N): + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1) + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_prev_kernel( + # Pointers to matrices + dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nchunks, nheads_ngroups_ratio, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + prev_states = prev_states.to(dout_ptrs.dtype.element_ty) + acc = tl.dot(dout, prev_states) + c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + ddA_cs = tl.sum(acc * c, axis=1) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_m) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + ddA_cs *= scale + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + + +def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Allocates output. + out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + ) + return out, out_x + + +def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert B.shape == C.shape + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Allocates output. + out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel_wip[grid]( + cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + D.stride(0) if D is not None else 0, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + BLOCK_SIZE_M=128, + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out, out_x + + +def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False): + batch, seqlen, nheads, headdim = x.shape + assert z.shape == x.shape + assert out.shape == x.shape + assert dout.shape == out.shape + nchunks = math.ceil(seqlen / chunk_size) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + if has_ddAcs: + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + if D is not None: + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + if dz is not None: + assert dz.shape == z.shape + else: + dz = torch.empty_like(z) + if recompute_output: + outz = torch.empty_like(x) + dout_x = torch.empty_like(dout) + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dz_kernel[grid_dz]( + dout, out, z, x, D, outz if recompute_output else None, + dz, dout_x, dD, ddA_cumsum if has_ddAcs else None, + chunk_size, headdim, + batch, seqlen, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + z.stride(0), z.stride(1), z.stride(2), z.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + D.stride(0) if D is not None else 0, + *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)), + dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3), + dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + *((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + if has_ddAcs else (0, 0, 0, 0)), + D is not None, + D.dim() == 2 if D is not None else True, + has_ddAcs, + BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), + RECOMPUTE_OUTPUT=recompute_output, + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD) + return return_vals if not recompute_output else (*return_vals, outz) + + +def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None): + batch, seqlen, nheads, headdim = dout.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + dtype = C.dtype if dtype is None else dtype + dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype) + grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(C.device.index): + _chunk_scan_bwd_dstates_kernel[grid_dstates]( + dout, C, dprev_states, dA_cumsum, seq_idx, + headdim, dstate, chunk_size, + batch, seqlen, nchunks, nheads // ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return dprev_states + + +def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1): + batch, nchunks, nheads, headdim, dstate = prev_states.shape + _, seqlen, _, _ = dout.shape + _, _, _, chunk_size = dA_cumsum.shape + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == (batch, seqlen, nheads, headdim) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if C is not None: + assert C.shape == (batch, seqlen, ngroups, dstate) + C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3)) + ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3)) + else: + C_strides = (0, 0, 0, 0) + ddA_cumsum_prev = None + ddA_cumsum_prev_strides = (0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32) + grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(dout.device.index): + _chunk_scan_bwd_dc_kernel[grid_dc]( + dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev, + chunk_size, dstate, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + *C_strides, + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4), + *ddA_cumsum_prev_strides, + HAS_DDA_CS=ddA_cumsum_prev is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dC = dC.sum(2) + return dC if C is None else (dC, ddA_cumsum_prev) + + +def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if CB is not None: + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4)) + BLOCK_SIZE_M_min = 16 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4)) + else: + CB_strides = (0, 0, 0, 0, 0) + ddA_cumsum = None + ddA_cumsum_strides = (0, 0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32) + grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dcb_kernel[grid_dcb]( + x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + *CB_strides, + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5), + *ddA_cumsum_strides, + HAS_DDA_CS=ddA_cumsum is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dcb = dcb.sum(2) + if ddA_cumsum is not None: + BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return dcb if CB is None else (dcb, ddA_cumsum) + + +def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + # if D is not None: + # BLOCK_SIZE_M_min = 32 + # dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32) + # else: + # dD = None + dx = torch.empty_like(x) + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dx_kernel[grid_dx]( + x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + D.stride(0) if D is not None else 0, + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0, + D is not None, + D.dim() == 2 if D is not None else True, + ) + # if D is not None: + # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + return dx, ddt.to(dtype=dt.dtype) + + +def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): + """Not numerically stable and should not be used. Leaving here for reference. + """ + + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert ddt.shape == dt.shape + assert out.shape == x.shape + assert dout.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + ddA_cumsum = torch.empty_like(dt) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + if D is not None: # Triton gives wrong results if we write to the same location + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs]( + dout, out, dt, ddt, x, D, ddA_cumsum, dD, + chunk_size, headdim, + batch, seqlen, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + D.stride(0) if D is not None else 0, + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + D is not None, + D.dim() == 2 if D is not None else True, + subtract_ddtdt, + BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return ddA_cumsum, dD + + +def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + assert dA_cumsum.shape == dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + BLOCK_SIZE_M_min = 16 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs]( + x, dout, dt, dA_cumsum, cb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16), + ) + BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum + + +def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + assert dA_cumsum.shape == dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + BLOCK_SIZE_M_min = 32 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs]( + x, dout, dt, dA_cumsum, cb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum + + +def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None): + batch, nchunks, nheads, headdim, dstate = prev_states.shape + _, seqlen, _, _ = dout.shape + _, _, _, chunk_size = dA_cumsum.shape + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == (batch, seqlen, nheads, headdim) + ngroups = C.shape[2] + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(dout.device.index): + _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs]( + dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev, + chunk_size, dstate, headdim, + batch, seqlen, nchunks, nheads // ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3), + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + return ddA_cumsum_prev + + +class ChunkScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + CB = _bmm_chunk_fwd(C, B, chunk_size) + out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z) + ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z) + return out + + @staticmethod + def backward(ctx, dout): + if dout.stride(-1) != 1: + dout = dout.contiguous() + out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert dout.shape == (batch, seqlen, nheads, headdim) + if z is not None: + dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D) + else: + dz = None + dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype) + dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups) + dC = dC.to(C.dtype) + dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups) + dCB = dCB.to(CB.dtype) + dB = _bmm_chunk_bwd(C, dCB) + dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC) + dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D) + # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. + # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt + if z is not None: + ddA_cumsum -= ddt * dt + else: # If z is not None, we already calculated ddA_cumsum and dD when computing dz + ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D) + ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype) + return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz + + +def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1. + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z) + + +def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + assert C.shape == B.shape + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), + rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + prev_states.to(C.dtype)) * state_decay_out + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return out if z is None else out * F.silu(z) diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py new file mode 100644 index 00000000..4333e6a7 --- /dev/null +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -0,0 +1,866 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + # Matrix dimension + batch, seqlen, nheads, chunk_size, + dt_min, dt_max, + # Strides + stride_dt_batch, stride_dt_seqlen, stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + pid_c = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) + tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_bwd_kernel( + # Pointers to matrices + ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr, + ddt_ptr, dA_ptr, ddt_bias_ptr, + # Matrix dimensions + batch, seqlen, nheads, chunk_size, + dt_min, dt_max, + # Strides + stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize, + stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize, + stride_dt_batch, stride_dt_seqlen, stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head, + stride_dA_head, + stride_ddt_bias_head, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + pid_c = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk + ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize) + ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) + ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + ddt = ddA * A[:, None] + ddt_out + dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt_presoftplus = dt + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), ddt) + clamp_mask = (dt < dt_min) | (dt > dt_max) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) + ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0) + ddt = tl.where(clamp_mask, 0.0, ddt) + if DT_SOFTPLUS: + ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt) + tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + dA = tl.sum(ddA * dt, axis=1) + tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) + if HAS_DT_BIAS: + ddt_bias = tl.sum(ddt, axis=1) + tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_state_bwd_dx_kernel( + # Pointers to matrices + x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, + dx_ptr, ddt_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) + if BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) + else: + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + ddA_cs = -(ddt * dt_m) + ddA_cs_last = -tl.sum(ddA_cs) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last) + + dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty) + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_state_bwd_db_kernel( + # Pointers to matrices + x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + db_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + if HAS_DDA_CS: + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + dstates = dstates.to(x_ptrs.dtype.element_ty) + db = tl.dot(x, dstates) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + db *= (scale * dt_m)[:, None] + if HAS_DDA_CS: + # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum + ddA_cs = tl.sum(db * b, axis=1) + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + acc += db + x_ptrs += stride_x_head + dstates_ptrs += stride_states_head + dt_ptrs += stride_dt_head + dA_cumsum_ptr += stride_dA_cs_head + dA_cumsum_ptrs += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptrs += stride_ddA_cs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # if HAS_SEQ_IDX: + # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0) + db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate) + tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_state_bwd_ddAcs_stable_kernel( + # Pointers to matrices + x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) + if BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) + else: + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + acc *= scale[:, None] + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + # ddA_cs = -(ddt * dt_m) + # Triton 2.2.0 errors if we have the cumsum here, so we just write it out + # then call torch.cumsum outside this kernel. + # ddA_cs = tl.cumsum(ddt * dt_m) + ddA_cs = ddt * dt_m + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + + +def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, A, dt_bias, dt_out, dA_cumsum, + batch, seqlen, nheads, chunk_size, + dt_limit[0], dt_limit[1], + dt.stride(0), dt.stride(1), dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None): + batch, seqlen, nheads = dt.shape + _, _, nchunks, chunk_size = ddA.shape + assert ddA.shape == (batch, nheads, nchunks, chunk_size) + assert ddt_out.shape == (batch, nheads, nchunks, chunk_size) + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) + else: + ddt_bias = None + if ddt is not None: + assert ddt.shape == dt.shape + else: + ddt = torch.empty_like(dt) + dA = torch.empty_like(A, dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_bwd_kernel[grid_chunk_cs]( + ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias, + batch, seqlen, nheads, chunk_size, + dt_limit[0], dt_limit[1], + ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3), + ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3), + dt.stride(0), dt.stride(1), dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + ddt.stride(0), ddt.stride(1), ddt.stride(2), + dA.stride(0), + ddt_bias.stride(0) if ddt_bias is not None else 0, + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return ddt, dA, ddt_bias + + +def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, B, states, dt, dA_cumsum, seq_idx, + headdim, dstate, chunk_size, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if dx is not None: + assert dx.shape == x.shape + else: + dx = torch.empty_like(x) + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_bwd_dx_kernel[grid_dx]( + x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + ) + return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype) + + +def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + dstate = dstates.shape[-1] + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B is not None: + assert B.shape == (batch, seqlen, ngroups, dstate) + B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3)) + # Use torch.empty since the Triton kernel will call init_to_zero + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + else: + B_strides = (0, 0, 0, 0) + ddA_cumsum = None + ddA_cumsum_strides = (0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32) + grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(x.device.index): + _chunk_state_bwd_db_kernel[grid_db]( + x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum, + chunk_size, dstate, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + *B_strides, + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4), + *ddA_cumsum_strides, + HAS_DDA_CS=ddA_cumsum is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dB = dB.sum(2) + if ddA_cumsum is not None: + # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute + # to the state of the chunk. + # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) + # But it's easier to just do the cumsum for all elements, the result will be the same. + torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum) + return dB if B is None else (dB, ddA_cumsum) + + +def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Use torch.empty since the Triton kernel will call init_to_zero + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs]( + x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + ) + torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) + return ddA_cumsum + + +class ChunkStateFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if B.stride(-1) != 1: + B = B.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32) + ctx.save_for_backward(B, x, dt, dA_cumsum) + return states + + @staticmethod + def backward(ctx, dstates): + B, x, dt, dA_cumsum = ctx.saved_tensors + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if dstates.stride(-1) != 1: + dstates = dstates.contiguous() + dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates) + dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups) + dB = dB.to(B.dtype) + return dB, dx, ddt, ddA_cumsum, None + + +def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True): + """ + Argument: + B: (batch, seqlen, ngroups, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32) + + +def chunk_state_ref(B, x, dt, dA_cumsum): + """ + Argument: + B: (batch, seqlen, ngroups, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + ngroups = B.shape[2] + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seqlen < nchunks * chunk_size: + x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) + B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) + decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py new file mode 100644 index 00000000..58d806b2 --- /dev/null +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -0,0 +1,959 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +from typing import Optional + +import math +from packaging import version + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +try: + from causal_conv1d import causal_conv1d_fn + import causal_conv1d_cuda +except ImportError: + causal_conv1d_fn, causal_conv1d_cuda = None, None + +from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd +from src.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd +from src.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db +from src.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable +from src.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref +from src.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd +from src.ops.triton.ssd_state_passing import state_passing, state_passing_ref +from src.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates +from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb +from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable +from src.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref +from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev +from src.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd +from src.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_scan_chunk_state_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr, + b_ptr, dstates_ptr, + dx_ptr, ddt_ptr, dD_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_D_head, + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_m) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) + # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + # However, we're getting error with the Triton compiler 2.1.0 for that code path: + # Unexpected mma -> mma layout conversion + # Triton 2.2.0 fixes this + offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate) + if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates) * scale[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate + acc *= scale[:, None] + + # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + # dt_ptrs = dt_ptr + offs_m * stride_dt_csize + # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + # ddt = tl.sum(acc * x, axis=1) * dt_m + # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit + K_MIN = pid_m * BLOCK_SIZE_M + cb_ptrs += K_MIN * stride_cb_csize_k + dout_ptrs += K_MIN * stride_dout_seqlen + dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize + for k in range(K_MIN, K_MAX, BLOCK_SIZE_K): + k = tl.multiple_of(k, BLOCK_SIZE_K) + # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + mask = k + offs_k[None, :] >= offs_m[:, None] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + dD = tl.sum(dout_res * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + else: + dD = tl.sum(dout_res * x) + tl.store(dD_ptr, dD) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + +def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + if dx is None: + dx = torch.empty_like(x) + else: + assert dx.shape == x.shape + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx]( + x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + D.stride(0) if D is not None else 0, + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + D is not None, + D.dim() == 2 if D is not None else True, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + IS_TRITON_22=TRITON_22 + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return dx, ddt.to(dtype=dt.dtype), dD + + +def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size) + # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus) + dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True) + # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True) + # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True) + states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype) + states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]] + # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) + # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate) + CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx) + return out, out_x, dt, dA_cumsum, states, final_states + + +def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, + dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False, + dt_limit=(0.0, float("inf")), + dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False): + if dout.stride(-1) != 1: + dout = dout.contiguous() + batch, seqlen, nheads, headdim = x.shape + nchunks = math.ceil(seqlen / chunk_size) + _, _, ngroups, dstate = B.shape + assert dout.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert C.shape == B.shape + assert out.shape == x.shape + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if dx is not None: + assert dx.shape == x.shape + if dB is not None: + assert dB.shape == B.shape + dB_given = dB + else: + dB_given = torch.empty_like(B) + if dC is not None: + assert dC.shape == C.shape + dC_given = dC + else: + dC_given = torch.empty_like(C) + if dz is not None: + assert z is not None + assert dz.shape == z.shape + if ddt is not None: + assert ddt.shape == dt.shape + ddt_given = ddt + else: + ddt_given = torch.empty_like(dt) + # TD: For some reason Triton (2.1.0 and 2.2.0) errors with + # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why. + dt_in = dt.clone() + dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, + dt_limit=dt_limit) + CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size) + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + if z is not None: + dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output) + outz = rest[0] if recompute_output else out + else: + dz = None + outz = out + dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype) + # dstates has length nchunks, containing the gradient to initial states at index 0 and + # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1) + # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states + # will be used in matmul in the next kernels. + dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + rearrange(dstates, "... p n -> ... (p n)"), + dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None, + seq_idx=seq_idx, + has_initial_states=initial_states is not None, + dstates_dtype=x.dtype, + states_dtype=x.dtype, + chunk_size=chunk_size, + ) + # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and + # gradient to the final states at index (nchunks - 1) + # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1) + # The final states is not stored. + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate) + dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None + dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx) + # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups) + dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups) + # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) + dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups) + # Computing ddA with the dcb kernel is much slower, so we're not using it for now + dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups) + # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups) + dCB = dCB.to(CB.dtype) + _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given) + _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given) + # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate + # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16 + if z is None: + dD = dD_from_x + # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. + # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt + # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might + # be a lot of underflow. + + # This is already done as part of bwd_dC kernel + # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx) + ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum + ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1]) + # This is already done as part of bwd_dB kernel + # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx) + # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j] + ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB) + ddA += ddA_next + ddA_prev + + ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given) + + # These 2 lines are just to test ddt and dA being computed by old code + # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z) + # ddt_given.copy_(ddt) + + return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states) + return return_vals if not recompute_output else (*return_vals, outz) + + +def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None): + """ + Argument: + dout: (batch, seqlen, nheads, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size) + A: (nheads) or (dim, dstate) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + import selective_scan + + batch, seqlen, nheads, headdim = x.shape + chunk_size = dt.shape[-1] + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + x = rearrange(x, "b l h p -> b (h p) l") + squeeze_dt = dt.dim() == 4 + if dt.dim() == 4: + dt = repeat(dt, "b h c l -> b h p c l", p=headdim) + dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim) + squeeze_A = A.dim() == 1 + if A.dim() == 1: + A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32) + else: + A = A.to(dtype=torch.float32) + B = rearrange(B, "b l g n -> b g n l") + C = rearrange(C, "b l g n -> b g n l") + if D is not None: + if D.dim() == 2: + D = rearrange(D, "h p -> (h p)") + else: + D = repeat(D, "h -> (h p)", p=headdim) + if z is not None: + z = rearrange(z, "b l h p -> b (h p) l") + + if x.stride(-1) != 1: + x = x.contiguous() + if dt.stride(-1) != 1: + dt = dt.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False) + if z is not None: + out = rest[0] + else: + out = None + + dout = rearrange(dout, "b l h p -> b (h p) l") + + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + _, ddt, dA, *rest = selective_scan.bwd( + x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False, + False # option to recompute out_z, not used here + ) + ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size) + if squeeze_dt: + ddt = ddt.float().sum(dim=2) + if squeeze_A: + dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2)) + return ddt, dA + + +class MambaChunkScanCombinedFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False): + ctx.dt_dtype = dt.dtype + out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit) + ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx) + ctx.dt_softplus = dt_softplus + ctx.chunk_size = chunk_size + ctx.dt_limit = dt_limit + ctx.return_final_states = return_final_states + return out if not return_final_states else (out, final_states) + + @staticmethod + def backward(ctx, dout, *args): + out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit) + return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None + + +def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + dt_softplus: Whether to apply softplus to dt + Return: + out: (batch, seqlen, nheads, headdim) + """ + return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states) + + +def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + if seqlen % chunk_size != 0: + dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size)) + dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) + dt = dt.float() # We want high precision for this before cumsum + if dt_bias is not None: + dt = dt + rearrange(dt_bias, "h -> h 1 1") + if dt_softplus: + dt = F.softplus(dt) + dA = dt * rearrange(A, "h -> h 1 1") + dA = dt * rearrange(A, "h -> h 1 1") + dA_cumsum = torch.cumsum(dA, dim=-1) + # 1. Compute the state for each chunk + states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True) + # 2. Pass the state to all the chunks by weighted cumsum. + states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0], + "... (p n) -> ... p n", n=dstate) + # 3. Compute the output for each chunk + out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z) + return out + + +def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + if seqlen % chunk_size != 0: + dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size)) + dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) + dt = dt.float() # We want high precision for this before cumsum + if dt_bias is not None: + dt = dt + rearrange(dt_bias, "h -> h 1 1") + if dt_softplus: + dt = F.softplus(dt) + dA = dt * rearrange(A, "h -> h 1 1") + dA_cumsum = torch.cumsum(dA, dim=-1) + # 1. Compute the state for each chunk + states = chunk_state_ref(B, x, dt, dA_cumsum) + states_dtype = states.dtype + if states.dtype not in [torch.float32, torch.float64]: + states = states.to(torch.float32) + # 2. Pass the state to all the chunks by weighted cumsum. + # state_passing_ref is much less numerically stable + states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0], + "... (p n) -> ... p n", n=dstate) + states = states.to(states_dtype) + # 3. Compute the output for each chunk + out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z) + return out + + +def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) + A: (nheads) or (dim, dstate) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) or (nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + from src.ops.selective_scan_interface import selective_scan_fn + + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + x = rearrange(x, "b l h p -> b (h p) l") + if dt.dim() == 3: + dt = repeat(dt, "b l h -> b l h p", p=headdim) + dt = rearrange(dt, "b l h p -> b (h p) l") + if A.dim() == 1: + A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32) + else: + A = A.to(dtype=torch.float32) + B = rearrange(B, "b l g n -> b g n l") + C = rearrange(C, "b l g n -> b g n l") + if D is not None: + if D.dim() == 2: + D = rearrange(D, "h p -> (h p)") + else: + D = repeat(D, "h -> (h p)", p=headdim) + if z is not None: + z = rearrange(z, "b l h p -> b (h p) l") + if dt_bias is not None: + if dt_bias.dim() == 1: + dt_bias = repeat(dt_bias, "h -> h p", p=headdim) + dt_bias = rearrange(dt_bias, "h p -> (h p)") + if dt_limit != (0.0, float("inf")): + if dt_bias is not None: + dt = dt + rearrange(dt_bias, "d -> d 1") + if dt_softplus: + dt = F.softplus(dt) + dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype) + dt_bias = None + dt_softplus = None + out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus) + return rearrange(out, "b (h p) l -> b l h p", p=headdim) + + +def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None, + dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + activation="silu", headdim=None, ngroups=1): + """ + Argument: + xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) + A: (nheads) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, dim) + dt_bias: (nheads) or (nheads, headdim) + headdim: if D is 1D and z is None, headdim must be passed in + Return: + out: (batch, seqlen, dim) + """ + batch, seqlen, nheads = dt.shape[:3] + assert nheads % ngroups == 0 + if z is not None: + dim = z.shape[-1] + assert dim % nheads == 0 + headdim = dim // nheads + else: + if D.dim() == 1: + assert headdim is not None + else: + headdim = D.shape[1] + dim = nheads * headdim + xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), + "b d s -> b s d") + dstate = (xBC.shape[-1] - dim) // ngroups // 2 + x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None + out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + return rearrange(out, "b s h p -> b s (h p)") + + +class MambaSplitConv1dScanCombinedFn(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", + rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, + ngroups=1, norm_before_gate=True): + assert activation in [None, "silu", "swish"] + if D.dim() == 1: + assert headdim is not None + nheads, = D.shape + else: + nheads, headdim = D.shape + batch, seqlen, _ = zxbcdt.shape + dim = nheads * headdim + assert nheads % ngroups == 0 + dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2 + d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2 + assert d_nonssm >= 0 + assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads) + assert dt_bias.shape == (nheads,) + assert A.shape == (nheads,) + zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1) + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + xBC_conv = rearrange( + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]), + "b d s -> b s d" + ) + x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None + if rmsnorm_weight is None: + out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit) + out = rearrange(out, "b s h p -> b s (h p)") + rstd = None + if d_nonssm > 0: + out = torch.cat([_swiglu_fwd(zx0), out], dim=-1) + else: + out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit) + # reshape input data into 2D tensor + x_rms = rearrange(out_x, "b s h p -> (b s) (h p)") + z_rms = rearrange(z, "b s h p -> (b s) (h p)") + rmsnorm_weight = rmsnorm_weight.contiguous() + if d_nonssm == 0: + out = None + else: + out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device) + out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d") + _swiglu_fwd(zx0, out=out01[..., :d_nonssm]) + out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out, + group_size=dim // ngroups, + norm_before_gate=norm_before_gate, is_rms_norm=True) + if d_nonssm == 0: + out = rearrange(out, "(b s) d -> b s d", b=batch) + else: + out = out01 + ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None + if outproj_weight is not None: + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + out, outproj_weight = out.to(dtype), outproj_weight.to(dtype) + outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None + out = F.linear(out, outproj_weight, outproj_bias) + else: + assert outproj_bias is None + ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias, + out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias) + ctx.dt_limit = dt_limit + ctx.return_final_states = return_final_states + ctx.activation = activation + ctx.rmsnorm_eps = rmsnorm_eps + ctx.norm_before_gate = norm_before_gate + ctx.chunk_size = chunk_size + ctx.headdim = headdim + ctx.ngroups = ngroups + return out if not return_final_states else (out, final_states) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + headdim = ctx.headdim + nheads = D.shape[0] + dim = nheads * headdim + assert nheads % ctx.ngroups == 0 + dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2 + d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2 + assert d_nonssm >= 0 + recompute_output = outproj_weight is not None + if recompute_output: + out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype) + out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1) + zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) + # Recompute x, B, C + xBC_conv = rearrange( + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]), + "b d s -> b s d" + ) + x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups) + dzxbcdt = torch.empty_like(zxbcdt) + dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) + dxBC = torch.empty_like(xBC) + dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) + dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads) + dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups) + dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups) + if outproj_weight is not None: + dout_og = dout + dout = F.linear(dout, outproj_weight.t()) + if d_nonssm > 0: + dout0, dout = dout.split([d_nonssm, dim], dim=-1) + _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute) + dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim) + if rmsnorm_weight is None: + dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads) + dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd( + dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output + ) + out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None + drmsnorm_weight = None + else: + batch = dout.shape[0] + dy_rms = rearrange(dout, "b s h p -> (b s) (h p)") + dz = rearrange(dz, "b l d -> (b l) d") + x_rms = rearrange(out, "b s h p -> (b s) (h p)") + z_rms = rearrange(z, "b s h p -> (b s) (h p)") + out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None + dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None) + out_for_linear = out_recompute if recompute_output else None + dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim) + dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd( + dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC + ) + + if outproj_weight is not None: + doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear) + doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None + else: + doutproj_weight, doutproj_bias = None, None + dxBC_given = rearrange(dxBC_given, "b s d -> b d s") + dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, + rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"] + ) + dxBC_given = rearrange(dxBC_given, "b d s -> b s d") + return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None + + +def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): + """ + Argument: + zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt_bias: (nheads,) + A: (nheads) + D: (nheads, headdim) or (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen), int32 + rmsnorm_weight: (dim,) + outproj_weight: (out_dim, dim) + outproj_bias: (out_dim,) + headdim: if D is 1D, headdim must be passed in + norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z)) + Return: + out: (batch, seqlen, dim) + """ + return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate) + + +def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): + """ + Argument: + zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim + conv1d_weight: (dim + 2 * ngroups * dstate, width) + conv1d_bias: (dim + 2 * ngroups * dstate,) + dt_bias: (nheads,) + A: (nheads) + D: (nheads, headdim) or (nheads,) + rmsnorm_weight: (dim,) + outproj_weight: (out_dim, dim) + outproj_bias: (out_dim,) + headdim: if D is 1D, headdim must be passed in + norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z)) + Return: + out: (batch, seqlen, dim) + """ + if D.dim() == 1: + assert headdim is not None + nheads, = D.shape + else: + nheads, headdim = D.shape + assert nheads % ngroups == 0 + batch, seqlen, _ = zxbcdt.shape + dim = nheads * headdim + dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2 + assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) + assert dt_bias.shape == (nheads,) + assert A.shape == (nheads,) + if rmsnorm_weight is not None: + assert rmsnorm_weight.shape == (dim,) + z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1) + xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), + "b d s -> b s d") + x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1) + x = rearrange(x, "b l (h p) -> b l h p", h=nheads) + B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) + C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) + z = rearrange(z, "b l (h p) -> b l h p", h=nheads) + out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), + z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit) + out = rearrange(out, "b s h p -> b s (h p)") + if rmsnorm_weight is not None: + out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps, + norm_before_gate=norm_before_gate) + if outproj_weight is not None: + out = F.linear(out, outproj_weight, outproj_bias) + return out + diff --git a/mamba_ssm/ops/triton/ssd_state_passing.py b/mamba_ssm/ops/triton/ssd_state_passing.py new file mode 100644 index 00000000..63863b82 --- /dev/null +++ b/mamba_ssm/ops/triton/ssd_state_passing.py @@ -0,0 +1,348 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, + # Matrix dimensions + dim, nchunks, seqlen, chunk_size, + # Strides + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim, + stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, + stride_final_states_batch, stride_final_states_head, stride_final_states_dim, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_initstates_batch, stride_initstates_head, stride_initstates_dim, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + if HAS_INITSTATES: + initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + else: + initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + seq_idx = seq_idx_new + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_bwd_kernel( + # Pointers to matrices + dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr, + dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr, + # Matrix dimensions + dim, nchunks, seqlen, chunk_size, + # Strides + stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim, + stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, + stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim, + # Meta-parameters + CONVERT_STATES: tl.constexpr, + HAS_DFINAL_STATES: tl.constexpr, + HAS_DINITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk + ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk + dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk + if CONVERT_STATES: + states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk + if HAS_DFINAL_STATES: + dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head + if HAS_DINITSTATES: + dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + dout_ptrs = dout_ptr + offs_m * stride_dout_dim + if CONVERT_STATES: + states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim + + if HAS_DFINAL_STATES: + dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32) + else: + dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + tl.store(dstates_ptrs, dstates, mask=offs_m < dim) + if HAS_SEQ_IDX: + seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen) + dstates_ptrs -= stride_dstates_chunk + for c in range(nchunks - 1): + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen)) + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + seq_idx = seq_idx_new + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if CONVERT_STATES: + tl.store(states_converted_ptrs, out, mask=offs_m < dim) + ddA = tl.sum(out * dstates) * scale + tl.store(ddA_cs_ptr, ddA) + dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dstates = scale * dstates + dout + tl.store(dstates_ptrs, dstates, mask=offs_m < dim) + dout_ptrs -= stride_dout_chunk + dstates_ptrs -= stride_dstates_chunk + dA_cs_ptr -= stride_dA_cs_chunk + ddA_cs_ptr -= stride_ddA_cs_chunk + out_ptrs -= stride_out_chunk + if CONVERT_STATES: + states_converted_ptrs -= stride_out_chunk + if CONVERT_STATES: + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + tl.store(states_converted_ptrs, out, mask=offs_m < dim) + if not HAS_DINITSTATES: + tl.store(ddA_cs_ptr, 0.0) + else: + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + scale = tl.where(seq_idx == 0, scale, 0.0) + out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + ddA = tl.sum(out * dstates) * scale + tl.store(ddA_cs_ptr, ddA) + dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dstates = scale * dstates + dout + tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim) + + +def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, + out_dtype=None): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if initial_states is not None: + assert initial_states.shape == (batch, nheads, dim) + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx, + dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, + states.stride(0), states.stride(1), states.stride(2), states.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + final_states.stride(0), final_states.stride(1), final_states.stride(2), + dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) + if initial_states is not None else (0, 0, 0)), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out, final_states + + +def _state_passing_bwd( + states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None, + dstates_dtype=None, states_dtype=None, chunk_size=None +): + """ + states contains the initial_states at index 0. The final states are not included in states. + """ + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + assert dout.shape == (batch, nchunks, nheads, dim) + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) + if states_dtype is not None and states_dtype != states.dtype: + states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) + assert states_converted.stride() == states.stride() + else: + states_converted = None + if has_initial_states: + dinitstates = torch.empty_like(dstates[:, 0]) + else: + dinitstates = None + if dfinal_states is not None: + assert dfinal_states.shape == (batch, nheads, dim) + BLOCK_SIZE_min = 64 + n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min + ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks, + dtype=torch.float32, device=dA_chunk_cumsum.device) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(dout.device.index): + _state_passing_bwd_kernel[grid]( + dout, states, dA_chunk_cumsum, dfinal_states, seq_idx, + dstates, ddA_chunk_cumsum, dinitstates, states_converted, + dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), + dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), + *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2)) + if dfinal_states is not None else (0, 0, 0)), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), + ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1), + *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2)) + if dinitstates is not None else (0, 0, 0)), + CONVERT_STATES=states_converted is not None, + HAS_DFINAL_STATES=dfinal_states is not None, + HAS_DINITSTATES=dinitstates is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"] + n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype) + if states_dtype is not None and states_dtype == states.dtype: + states_converted = states + return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted) + + +class StatePassingFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, states, dA_chunk_cumsum, initial_states=None): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if states.stride(-1) != 1: + states = states.contiguous() + out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states) + ctx.save_for_backward(out, dA_chunk_cumsum) + ctx.has_initial_states = initial_states is not None + return out, final_states + + @staticmethod + def backward(ctx, dout, dfinal_states): + out, dA_chunk_cumsum = ctx.saved_tensors + batch, nchunks, nheads, dim = out.shape + assert dout.shape == (batch, nchunks, nheads, dim) + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + assert dfinal_states.shape == (batch, nheads, dim) + if dout.stride(-1) != 1: + dout = dout.contiguous() + dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd( + out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states + ) + return dstates, ddA_chunk_cumsum, dinitstates + + +def state_passing(states, dA_chunk_cumsum, initial_states=None): + """ + Argument: + states: (batch, nchunks, nheads, dim) + dA_chunk_cumsum: (batch, nheads, nchunks) + initial_states: (batch, nheads, dim) + Return: + out: (batch, nchunks, nheads, dim) + final_states: (batch, nheads, dim) + """ + return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states) + + +def state_passing_ref(states, dA_chunk_cumsum, initial_states=None): + """ + Argument: + states: (batch, nchunks, nheads, dim) + dA_chunk_cumsum: (batch, nheads, nchunks) + initial_states: (batch, nheads, dim) + Return: + out: (batch, nchunks, nheads, dim) + final_states: (batch, nheads, dim) + """ + if initial_states is None: + initial_states = torch.zeros_like(states[:, 0]) + states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1) + dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0)) + dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1) + nchunks = dA_chunk_cumsum.shape[-1] + # (batch, nheads, nchunks, nchunks) + dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :] + # (batch, nheads, nchunks, nchunks) + decay_chunk = torch.exp(dt_chunk_segment_sum) + causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0) + decay_chunk = decay_chunk.masked_fill(~causal_mask, 0) + out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states) + return out[:, :-1], out[:, -1]