From 154255a662d81cb152949c8712dd8ff664027bc0 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 12 Dec 2024 07:11:31 +0000
Subject: [PATCH] more comments

Signed-off-by: Yu Chin Fabian Lim
---
 .../decoder_only/language/test_bamba.py       |  2 --
 .../layers/mamba/ops/ssd_combined.py          | 32 +++++++++++++------
 2 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/tests/models/decoder_only/language/test_bamba.py b/tests/models/decoder_only/language/test_bamba.py
index d266135360563..96efdc59081d3 100644
--- a/tests/models/decoder_only/language/test_bamba.py
+++ b/tests/models/decoder_only/language/test_bamba.py
@@ -20,8 +20,6 @@
 # choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
 def generate_greedy(model_name, example_prompts, max_tokens):
     # Create a text generation pipeline
-    # - in the original test_mamba.py they do not put the model to cuda
-    #   maybe this affects the test.
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
 
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
index 90854fd0c0a10..5c468662bb4a5 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -67,25 +67,36 @@ def _mamba_chunk_scan_combined_fwd(x,
         D = D.contiguous()
     if initial_states is not None:
         assert initial_states.shape == (batch, nheads, headdim, dstate)
-    # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
-    # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
-    # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
-    # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
+
+    # This function executes 5 sub-functions for computing mamba
+    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
+    #   which has a minimal implementation to understand the operations below
+    # - as explained by the blog, mamba is a special case of causal attention
+    # - the idea is to chunk the attention matrix and compute each
+    #   submatrix separately using different optimizations
+    # - see the blog and paper for a visualization of the submatrices
+    #   which we refer to in the comments below
+
+    # 1. Compute chunked cumsum of A * dt
+    # - here dt may go through a softplus activation
     dA_cumsum, dt = _chunk_cumsum_fwd(dt,
                                       A,
                                       chunk_size,
                                       dt_bias=dt_bias,
                                       dt_softplus=dt_softplus,
                                       dt_limit=dt_limit)
+
+    # 2. Compute the state for each intra-chunk
+    # (right term of low-rank factorization of off-diagonal blocks; B terms)
     states = _chunk_state_fwd(B,
                               x,
                               dt,
                               dA_cumsum,
                               seq_idx=seq_idx,
                               states_in_fp32=True)
-    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
-    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
-    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
+
+    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+    # (middle term of factorization of off-diag blocks; A terms)
     states, final_states = _state_passing_fwd(
         rearrange(states, "... p n -> ... (p n)"),
         dA_cumsum[:, :, :, -1],
@@ -96,13 +107,16 @@ def _mamba_chunk_scan_combined_fwd(x,
         out_dtype=C.dtype)
     states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                             for t in [states, final_states])
-    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
-    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
+
+    # 4. Compute batched matrix multiply for C_j^T B_i terms
     CB = _bmm_chunk_fwd(C,
                         B,
                         chunk_size,
                         seq_idx=seq_idx,
                         output_dtype=torch.float32)
+
+    # 5. Scan and compute the diagonal blocks, taking into
+    #    account past causal states.
     out, out_x = _chunk_scan_fwd(CB,
                                  x,
                                  dt,