From 0eb1b2aad9ae9bba70349236194e6e7d86772e6d Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 3 Nov 2024 19:15:24 +0200 Subject: [PATCH] Revert changes Signed-off-by: mzusman --- tests/kernels/test_causal_conv1d.py | 30 +++++++++---------- .../layers/mamba/mamba_mixer.py | 7 +++-- .../layers/mamba/ops/causal_conv1d.py | 3 -- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 43222b5c30dcd..f9b11018288be 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -2,15 +2,13 @@ import pytest import torch +import torch.nn.functional as F from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_native, causal_conv1d_update, - causal_conv1d_update_native) - - + causal_conv1d_fn, causal_conv1d_update) from vllm.platforms import current_platform @@ -188,7 +186,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, has_initial_state=torch.ones(batch, dtype=torch.bool, device=x.device)) - out_ref, final_states_ref = causal_conv1d_native( + out_ref, final_states_ref = causal_conv1d_ref( x_ref, weight_ref, bias_ref, @@ -240,11 +238,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, weight, bias, activation=activation) - out_ref = causal_conv1d_update_native(x_ref, - conv_state_ref, - weight, - bias, - activation=activation) + out_ref = causal_conv1d_update_ref(x_ref, + conv_state_ref, + weight, + bias, + activation=activation) assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) @@ -312,11 +310,11 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, activation=activation, conv_state_indices=padded_state_indices, pad_slot_id=PAD_SLOT_ID) - out_ref = causal_conv1d_update_native(x_ref[:batch_size], - conv_state_ref, - weight, - bias, - activation=activation) + out_ref = causal_conv1d_update_ref(x_ref[:batch_size], + conv_state_ref, + weight, + bias, + activation=activation) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @@ -407,7 +405,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, if padded_state_indices[i] == PAD_SLOT_ID: continue out_ref_b.append( - causal_conv1d_native( + causal_conv1d_ref( x_s, weight_ref, bias_ref, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index b200d3ebf9c29..8ef0a6cdf2c52 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -7,7 +7,8 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) @@ -59,8 +60,8 @@ def __init__(self, self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.in_proj = MergedColumnParallelLinear(hidden_size, - [intermediate_size] * 2, - bias=use_bias) + [intermediate_size] * 2, + bias=use_bias) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( intermediate_size, diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index c1381762ce2a8..be5639df985fa 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -4,7 +4,6 @@ from typing import Optional import torch -import torch.nn.functional as F from vllm import _custom_ops as ops from vllm.attention.backends.utils import PAD_SLOT_ID @@ -101,5 +100,3 @@ def causal_conv1d_update(x: torch.Tensor, if unsqueeze: x = x.squeeze(-1) return x - -