
Chunk-Wise Inference does not match Full-Length Inference #641

Open
@WorldEditors

Description

The reproduction code:

from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams
import torch

bsz = 1
seq_len = 256
seg_len = 64
dim = 512
seg_num = (seq_len - 1) // seg_len + 1

x = torch.randn(bsz, seq_len, dim).to("cuda")

inference_params = InferenceParams(max_seqlen=seq_len, max_batch_size=bsz)

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    layer_idx=0
).to("cuda")

y = model(x, inference_params=inference_params)

# Reset the cached conv and SSM states before the chunk-wise pass
inference_params.key_value_memory_dict[0][0].zero_()
inference_params.key_value_memory_dict[0][1].zero_()

# Re-run the same input chunk by chunk; the state carried in
# inference_params should make each segment's output match the
# corresponding slice of the full-length pass.
for i in range(seg_num):
    b = i * seg_len
    e = min(b + seg_len, seq_len)
    yseg = model(x[:, b:e], inference_params=inference_params)
    error = y[:, b:e] - yseg
    print(f"Seg {i} error:", torch.sum(error ** 2))

assert y.shape == x.shape

Expected: the per-segment error should be close to zero, but instead we get:

Seg 0 error: tensor(1.0139e-08, device='cuda:0', grad_fn=<SumBackward0>)
Seg 1 error: tensor(158.5197, device='cuda:0', grad_fn=<SumBackward0>)
Seg 2 error: tensor(141.9800, device='cuda:0', grad_fn=<SumBackward0>)
Seg 3 error: tensor(161.3486, device='cuda:0', grad_fn=<SumBackward0>)
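For reference, here is the equivalence the reproduction above expects, verified on a toy diagonal linear SSM recurrence (h_t = a*h_{t-1} + b*x_t, y_t = c*h_t). This is an illustrative sketch in plain NumPy, not the Mamba2 kernel: when the hidden state is carried across segments, chunk-wise outputs match the full-length scan exactly.

```python
import numpy as np

def ssm_scan(x, a, b, c, h0):
    """Run the recurrence h_t = a*h_{t-1} + b*x_t over x;
    return the outputs y_t = c*h_t and the final state."""
    h = h0
    ys = []
    for xt in x:
        h = a * h + b * xt
        ys.append(c * h)
    return np.array(ys), h

rng = np.random.default_rng(0)
seq_len, seg_len, dim = 256, 64, 8
x = rng.standard_normal((seq_len, dim))
a, b, c = 0.9, 0.5, 1.0

# Full-length pass from a zero initial state.
y_full, _ = ssm_scan(x, a, b, c, h0=np.zeros(dim))

# Chunk-wise pass, carrying the hidden state across segment boundaries.
h = np.zeros(dim)
y_chunks = []
for s in range(0, seq_len, seg_len):
    y_seg, h = ssm_scan(x[s:s + seg_len], a, b, c, h0=h)
    y_chunks.append(y_seg)
y_chunked = np.concatenate(y_chunks)

print(float(((y_full - y_chunked) ** 2).sum()))  # 0.0 -- identical ops, identical result
```

Since both passes execute the same floating-point operations in the same order, the outputs are bitwise identical here; for Mamba2 one would only expect errors at kernel-precision level (as in Seg 0), not the large discrepancies seen for Seg 1 onward.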
