[BugFix][Kernel]: fix illegal memory access in causal_conv1d when conv_states is None (#10928)

Signed-off-by: xffxff <[email protected]>
xffxff authored Dec 7, 2024
1 parent c889d58 commit 78029b3
Showing 2 changed files with 23 additions and 18 deletions.
csrc/mamba/causal_conv1d/causal_conv1d.cu (1 addition, 1 deletion)
@@ -424,7 +424,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
 // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
 // (which occurs when `final_state_position` is a non-positive index)
 // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
-if (final_state_position < 0 && seqlen > kWidth){
+if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
     input_t vals_load[kNElts] = {0};
     if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
         // chunk = n_chunks - 2, a segment of the final state sits in the last index
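Context for the one-line kernel change: when the caller provides no cached convolution state (the Python side passes conv_states=None), the kernel receives a null conv_states pointer, yet this long-sequence path (seqlen > kWidth with a non-positive final_state_position) still assembled and stored the final state through it, which is the illegal memory access named in the title. The added conv_states != nullptr check skips that store when there is no state buffer. Below is a minimal, hedged Python sketch of a call that exercises the previously crashing configuration; the import path is an assumption (the test's import lines are not part of this diff), and the (batch, dim, seqlen) layout of x and the "silu" activation string are likewise assumptions based on the test setup.

import torch

# Hedged reproduction sketch for the pre-fix crash. The import path below is
# an assumption; the keyword arguments mirror the test shown further down.
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 1, 64, 1025, 4
# Shapes follow the test; x laid out as (batch, dim, seqlen) is an assumption
# about test lines outside the shown hunks.
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(dim, device="cuda", dtype=torch.bfloat16)

# conv_states=None hands the kernel a null state pointer. Before this commit
# the long-sequence final-state path still wrote through it; with the nullptr
# guard the call is expected to complete normally.
out = causal_conv1d_fn(x,
                       weight,
                       bias,
                       activation="silu",
                       conv_states=None,
                       has_initial_state=None)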
tests/kernels/test_causal_conv1d.py (22 additions, 17 deletions)
@@ -149,13 +149,14 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
 @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
 @pytest.mark.parametrize("silu_activation", [True])
 @pytest.mark.parametrize("has_bias", [True])
+@pytest.mark.parametrize("has_initial_state", [True, False])
 @pytest.mark.parametrize("width", [4])
 @pytest.mark.parametrize(
     'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
 @pytest.mark.parametrize('dim', [64])
 @pytest.mark.parametrize('batch', [1])
 def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
-                       itype):
+                       has_initial_state, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
     if itype == torch.bfloat16:
@@ -167,11 +168,18 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,

     weight = torch.randn(dim, width, device=device, dtype=itype)
     bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
-    initial_states = torch.randn(batch,
-                                 dim,
-                                 width - 1,
-                                 device=device,
-                                 dtype=itype)
+    if has_initial_state:
+        initial_states = torch.randn(batch,
+                                     dim,
+                                     width - 1,
+                                     device=device,
+                                     dtype=itype)
+        has_initial_state_tensor = torch.ones(batch,
+                                              dtype=torch.bool,
+                                              device=x.device)
+    else:
+        initial_states = None
+        has_initial_state_tensor = None
     x_ref = x.clone()
     weight_ref = weight.clone()
     bias_ref = bias.clone() if bias is not None else None
@@ -183,31 +191,28 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
                            bias,
                            activation=activation,
                            conv_states=initial_states,
-                           has_initial_state=torch.ones(batch,
-                                                        dtype=torch.bool,
-                                                        device=x.device))
+                           has_initial_state=has_initial_state_tensor)
     out_ref, final_states_ref = causal_conv1d_ref(
         x_ref,
         weight_ref,
         bias_ref,
         initial_states=initial_states_ref,
         return_final_states=True,
         activation=activation)
-    assert initial_states is not None and final_states_ref is not None
-    assert torch.allclose(initial_states,
-                          final_states_ref,
-                          rtol=rtol,
-                          atol=atol)
+    if has_initial_state:
+        assert initial_states is not None and final_states_ref is not None
+        assert torch.allclose(initial_states,
+                              final_states_ref,
+                              rtol=rtol,
+                              atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

     causal_conv1d_opcheck_fn(x,
                              weight,
                              bias,
                              activation=activation,
                              conv_states=initial_states,
-                             has_initial_state=torch.ones(batch,
-                                                          dtype=torch.bool,
-                                                          device=x.device))
+                             has_initial_state=has_initial_state_tensor)


 @pytest.mark.parametrize("itype", [torch.bfloat16])
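The updated test exercises both sides of the new parametrization: has_initial_state=True keeps the previous behavior (a random conv_states buffer plus an all-ones boolean mask), while has_initial_state=False passes None for both, which is exactly the configuration that used to crash the kernel; the final-state comparison is guarded accordingly, since there is no conv_states buffer to check in that case. For reference, the sketch below captures the semantics the reference comparison encodes: a depthwise causal convolution whose history is either the supplied initial_states or implicit zeros, with the final state being the last width - 1 columns of the padded input. It is a minimal sketch assuming the tensor shapes used in the test, it is not the actual causal_conv1d_ref implementation, and the silu handling is an assumption based on the silu_activation parametrization.

import torch
import torch.nn.functional as F


def causal_conv1d_reference(x, weight, bias=None, initial_states=None,
                            activation=None):
    # x: (batch, dim, seqlen), weight: (dim, width),
    # initial_states: (batch, dim, width - 1) or None.
    dim, width = weight.shape
    if initial_states is None:
        # No cached history: pad with zeros, matching the conv_states=None path.
        x_padded = F.pad(x, (width - 1, 0))
    else:
        x_padded = torch.cat([initial_states.to(x.dtype), x], dim=-1)
    # Depthwise (per-channel) causal convolution.
    out = F.conv1d(x_padded, weight.unsqueeze(1), bias, groups=dim)
    if activation == "silu":
        out = F.silu(out)
    # The final state is the last (width - 1) columns of the padded input.
    final_states = x_padded[..., -(width - 1):]
    return out, final_states


# Example: no initial state, i.e. the previously crashing configuration on
# the kernel side.
out, final = causal_conv1d_reference(torch.randn(1, 64, 32),
                                     torch.randn(64, 4),
                                     torch.randn(64),
                                     initial_states=None,
                                     activation="silu")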

