Chunk-Wise Inference does not match Full-Length Inference #641

Open
WorldEditors opened this issue Dec 9, 2024 · 0 comments

The reproduction code:

from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams
import torch

bsz = 1
seq_len = 256
seg_len = 64
dim = 512
seg_num = (seq_len - 1) // seg_len + 1

x = torch.randn(bsz, seq_len, dim).to("cuda")

inference_params = InferenceParams(max_seqlen=seq_len, max_batch_size=bsz)

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    layer_idx=0
).to("cuda")

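# Reference: run the full sequence in a single call.
y = model(x, inference_params=inference_params)

# Zero the states cached for layer_idx=0 before the chunk-wise pass.
inference_params.key_value_memory_dict[0][0].zero_()
inference_params.key_value_memory_dict[0][1].zero_()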

# Chunk-wise pass: feed seg_len tokens per call, reusing the same inference_params.
for i in range(seg_num):
    b = i * seg_len
    e = min(b + seg_len, seq_len)
    yseg = model(x[:, b:e], inference_params=inference_params)
    error = y[:, b:e] - yseg
    print(f"Seg {i} error:", torch.sum(error ** 2))

assert y.shape == x.shape

Expected: the squared error is close to zero for every segment.
Actual: only the first segment matches the full-length output:

Seg 0 error: tensor(1.0139e-08, device='cuda:0', grad_fn=<SumBackward0>)
Seg 1 error: tensor(158.5197, device='cuda:0', grad_fn=<SumBackward0>)
Seg 2 error: tensor(141.9800, device='cuda:0', grad_fn=<SumBackward0>)
Seg 3 error: tensor(161.3486, device='cuda:0', grad_fn=<SumBackward0>)
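
To make the expectation concrete, here is a toy sketch in plain Python. It is not Mamba2 and uses none of its kernels: `scan`, `chunked`, `seg_errors`, and the constants `a`, `b`, `c` are hypothetical names made up for illustration. It runs a scalar linear recurrence over the whole sequence and then segment by segment. When the final state of one segment is passed in as the initial state of the next, the outputs agree with the full-length run; when the state is reset at every segment boundary, only segment 0 agrees. That second pattern resembles the errors above, but whether the same thing happens inside Mamba2 is an assumption and is not confirmed here.

```python
import random


def scan(xs, h0=0.0, a=0.9, b=0.5, c=2.0):
    """Toy linear recurrence: h_t = a * h_{t-1} + b * x_t, y_t = c * h_t."""
    h, ys = h0, []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys, h  # outputs and the final state


def chunked(xs, seg_len, carry_state):
    """Process xs in segments, optionally carrying the state across segments."""
    h, ys = 0.0, []
    for start in range(0, len(xs), seg_len):
        seg_ys, h_last = scan(xs[start:start + seg_len], h0=h)
        ys.extend(seg_ys)
        # Either hand the final state to the next segment or start from zero again.
        h = h_last if carry_state else 0.0
    return ys


def seg_errors(ref, out, seg_len):
    """Sum of squared differences per segment, like the report above."""
    return [
        sum((r - o) ** 2 for r, o in zip(ref[s:s + seg_len], out[s:s + seg_len]))
        for s in range(0, len(ref), seg_len)
    ]


random.seed(0)
seq_len, seg_len = 256, 64
xs = [random.gauss(0.0, 1.0) for _ in range(seq_len)]

full, _ = scan(xs)
carried_errors = seg_errors(full, chunked(xs, seg_len, carry_state=True), seg_len)
reset_errors = seg_errors(full, chunked(xs, seg_len, carry_state=False), seg_len)

print("carried:", carried_errors)
print("reset:  ", reset_errors)
```

With the state carried, every segment error is zero; with the state reset, every segment after the first has a nonzero error. The chunk-wise loop in the reproduction is expected to behave like the first case.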