Skip to content

Commit

Permalink
more comments
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Dec 12, 2024
1 parent b2dc5ca commit 154255a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
2 changes: 0 additions & 2 deletions tests/models/decoder_only/language/test_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
# Create a text generation pipeline
# - in the original test_mamba.py they do not put the model to cuda
# maybe this affects the test.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

Expand Down
32 changes: 23 additions & 9 deletions vllm/model_executor/layers/mamba/ops/ssd_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,25 +67,36 @@ def _mamba_chunk_scan_combined_fwd(x,
D = D.contiguous()
if initial_states is not None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
# # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
# dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
# dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
# dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)

# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
# which has a minimal implementation to understand the below operations
# - as explained by the blog, mamba is a special case of causal attention
# - the idea is to chunk the attention matrix and compute each
# submatrix separately using different optimizations.
# - see the blog and paper for a visualization of the submatrices
# which we refer to in the comments below

# 1. Compute chunked cumsum of A * dt
# - here dt may go through a softplus activation
dA_cumsum, dt = _chunk_cumsum_fwd(dt,
A,
chunk_size,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
dt_limit=dt_limit)

# 2. Compute the intra-chunk state for each chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
states = _chunk_state_fwd(B,
x,
dt,
dA_cumsum,
seq_idx=seq_idx,
states_in_fp32=True)
# states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
# states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
# states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)

# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
states, final_states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum[:, :, :, -1],
Expand All @@ -96,13 +107,16 @@ def _mamba_chunk_scan_combined_fwd(x,
out_dtype=C.dtype)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
# states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
# states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)

# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
B,
chunk_size,
seq_idx=seq_idx,
output_dtype=torch.float32)

# 5. Scan and compute the diagonal blocks, taking into
# account past causal states.
out, out_x = _chunk_scan_fwd(CB,
x,
dt,
Expand Down

0 comments on commit 154255a

Please sign in to comment.