From 162361f4e2f1a68f84150c55cb7ac6980bc51e4d Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 30 Oct 2024 13:33:33 +0200 Subject: [PATCH 1/4] Mamba test relieve bfloat16 tolerance constraint to match update with update, and small fix in causal_conv1d kernel Signed-off-by: mzusman --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 7 +++++-- tests/kernels/test_mamba_ssm.py | 6 +++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 3a464c5f327ad..cbe0d1e9a99c0 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -446,9 +446,12 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { } else { // in case the final state is in between the threads data - reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; - reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); + if ((offset + kWidth - 2) >= kNElts){ + // do not load to index 1 if we're not gonna read from there + reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; + } + reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; #pragma unroll for (int w = 0; w < kWidth - 1; ++w){ conv_states[w] = x_vals_load[offset + w ]; diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index bf7ff3b5c59b8..ad05a97685351 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: - rtol, atol = 7e-2, 7e-2 + rtol, atol = 1e-1, 1e-1 if torch.version.hip: atol *= 2 # set seed @@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, dt_bias=dt_bias, dt_softplus=True) - print("Output diff 
max", (out - out_ref[0]).max()) - print("Output diff mean", (out - out_ref[0]).mean()) + print("Output diff max", (out[:batch_size] - out_ref).max()) + print("Output diff mean", (out[:batch_size] - out_ref).mean()) print("Output state diff max", (state[state_indices, :] - state_ref).max()) print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) From 1df71e263024c66de3b98491fd252b4b156e5963 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 31 Oct 2024 15:20:15 +0200 Subject: [PATCH 2/4] Add another fix to causal_conv1d, where final state data is separated between 2 chunks Signed-off-by: mzusman --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 25 +++++++++++++++++++++++ tests/kernels/test_causal_conv1d.py | 7 +++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index cbe0d1e9a99c0..e2458aa6823ad 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); } out += kChunkSize; + + int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); + // in case the final state is seperated between the last "smem_exchange" and + // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), + // (which occurs when `final_state_position` is a non-positivie index) + // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it + if (final_state_position < 0 && seqlen > kWidth){ + input_t vals_load[kNElts] = {0}; + if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ + // chunk = n_chunks - 2, a segment of the final state sits in the last index + reinterpret_cast(vals_load)[0] = smem_exchange[kNThreads - 1]; + #pragma unroll + for (int w = 0; w < 
-final_state_position; ++w){ + conv_states[w] = vals_load[kNElts + final_state_position + w]; + } + } + if ((chunk == n_chunks - 1) && tidx == 0){ + // chunk = n_chunks - 1, the second segment of the final state first positions + reinterpret_cast(vals_load)[0] = smem_exchange[0]; + for (int w = -final_state_position; w < kWidth - 1; ++w){ + conv_states[w] = vals_load[w + final_state_position]; + } + return; + } + } } // Final state is stored in the smem_exchange last token slot, // in case seqlen < kWidth, we would need to take the final state from the diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 96bfe06d74ae5..f9b11018288be 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize( - 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) + 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]) @pytest.mark.parametrize('dim', [64]) @pytest.mark.parametrize('batch', [1]) def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, @@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, unpadded_out = out[:, :out_ref_tensor.shape[-1]] assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) - assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) + assert torch.allclose(final_states[state_indices], + final_states_ref[state_indices], + rtol=rtol, + atol=atol) causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), padded_state_indices, has_initial_states, From 1376762b66ae32ff387b64d33d56f9a098911f98 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 31 Oct 2024 15:22:04 +0200 Subject: [PATCH 3/4] Format Signed-off-by: mzusman --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index e2458aa6823ad..4f71de15cc992 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -420,7 +420,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { out += kChunkSize; int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); - // in case the final state is seperated between the last "smem_exchange" and + // in case the final state is separated between the last "smem_exchange" and // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), // (which occurs when `final_state_position` is a non-positivie index) // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it From ac66eafba1304cb544142bb7419e6461024227da Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 31 Oct 2024 15:36:49 +0200 Subject: [PATCH 4/4] Add more explanation to the illegal access error Signed-off-by: mzusman --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 4f71de15cc992..498d069c05f0d 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -472,8 +472,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { else { // in case the final state is in between the threads data const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); - if ((offset + kWidth - 2) >= kNElts){ - // do not load to index 1 if we're not gonna read from there + if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){ + // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a + // illegal access error on H100. 
+ // Therefore, we access last_thread + 1, only if the final state data sits there reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; } reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread];