The reproduction code:
from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams
import torch
bsz = 1
seq_len = 256
seg_len = 64
dim = 512
seg_num = (seq_len - 1) // seg_len + 1
x = torch.randn(bsz, seq_len, dim).to("cuda")
inference_params = InferenceParams(max_seqlen=seq_len, max_batch_size=bsz)
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,   # Model dimension d_model
    d_state=64,    # SSM state expansion factor, typically 64 or 128
    d_conv=4,      # Local convolution width
    expand=2,      # Block expansion factor
    layer_idx=0,   # Needed so the layer can index into key_value_memory_dict
).to("cuda")
# Full-sequence forward pass (also fills the inference cache for layer 0)
y = model(x, inference_params=inference_params)
# Reset the cached (conv_state, ssm_state) for layer 0 before the segmented run
inference_params.key_value_memory_dict[0][0].zero_()
inference_params.key_value_memory_dict[0][1].zero_()
# Re-run the same sequence in segments, reusing the cache between calls
for i in range(seg_num):
    b = i * seg_len
    e = min(b + seg_len, seq_len)
    yseg = model(x[:, b:e], inference_params=inference_params)
    error = y[:, b:e] - yseg
    print(f"Seg {i} error:", torch.sum(error ** 2))
assert y.shape == x.shape
We expect the per-segment error to be close to zero, but instead we get:
Seg 0 error: tensor(1.0139e-08, device='cuda:0', grad_fn=<SumBackward0>)
Seg 1 error: tensor(158.5197, device='cuda:0', grad_fn=<SumBackward0>)
Seg 2 error: tensor(141.9800, device='cuda:0', grad_fn=<SumBackward0>)
Seg 3 error: tensor(161.3486, device='cuda:0', grad_fn=<SumBackward0>)
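For comparison, here is a minimal sketch of a token-by-token variant. It assumes (from our reading of mamba_ssm/utils/generation.py and Mamba2.forward) that the caller is responsible for advancing inference_params.seqlen_offset, and that once the offset is nonzero the layer takes its single-token step path using the cached conv_state/ssm_state. We are not sure whether the multi-token segment calls in the repro above are expected to resume from the cached state in the same way, so this is only a point of comparison, not a fix.

# Sketch only: token-by-token decoding with the same model and cache.
# Assumes the caller must advance seqlen_offset, as the generation utilities do.
inference_params.key_value_memory_dict[0][0].zero_()
inference_params.key_value_memory_dict[0][1].zero_()
inference_params.seqlen_offset = 0

y_step = []
for t in range(seq_len):
    # First call (offset 0) runs the normal forward on one token and fills the
    # cache; later calls (offset > 0) should use the cached-state step path.
    y_t = model(x[:, t:t + 1], inference_params=inference_params)
    inference_params.seqlen_offset += 1
    y_step.append(y_t)

y_step = torch.cat(y_step, dim=1)
print("Token-by-token error:", torch.sum((y - y_step) ** 2))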