
Commit 0eb1b2a: Revert changes

Signed-off-by: mzusman <[email protected]>
mzusman committed Nov 3, 2024
1 parent cdb6773 commit 0eb1b2a
Showing 3 changed files with 18 additions and 22 deletions.

tests/kernels/test_causal_conv1d.py (14 additions, 16 deletions)
@@ -2,15 +2,13 @@

 import pytest
 import torch
+import torch.nn.functional as F

 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops  # noqa: F401
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
-    causal_conv1d_fn, causal_conv1d_native, causal_conv1d_update,
-    causal_conv1d_update_native)
-
-
+    causal_conv1d_fn, causal_conv1d_update)
 from vllm.platforms import current_platform

@@ -188,7 +186,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
                                          has_initial_state=torch.ones(batch,
                                                                       dtype=torch.bool,
                                                                       device=x.device))
-    out_ref, final_states_ref = causal_conv1d_native(
+    out_ref, final_states_ref = causal_conv1d_ref(
         x_ref,
         weight_ref,
         bias_ref,
@@ -240,11 +238,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
                                weight,
                                bias,
                                activation=activation)
-    out_ref = causal_conv1d_update_native(x_ref,
-                                          conv_state_ref,
-                                          weight,
-                                          bias,
-                                          activation=activation)
+    out_ref = causal_conv1d_update_ref(x_ref,
+                                       conv_state_ref,
+                                       weight,
+                                       bias,
+                                       activation=activation)

     assert torch.equal(conv_state, conv_state_ref)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@@ -312,11 +310,11 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
                                activation=activation,
                                conv_state_indices=padded_state_indices,
                                pad_slot_id=PAD_SLOT_ID)
-    out_ref = causal_conv1d_update_native(x_ref[:batch_size],
-                                          conv_state_ref,
-                                          weight,
-                                          bias,
-                                          activation=activation)
+    out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
+                                       conv_state_ref,
+                                       weight,
+                                       bias,
+                                       activation=activation)

     assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
     assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
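For context on the batch-gather test above: `conv_state_indices` maps each batch entry to a slot in a shared conv-state cache, and entries padded with `PAD_SLOT_ID` are expected to be skipped entirely. The sketch below illustrates that indexing scheme only; the tensor shapes and the sentinel value are illustrative assumptions, not vLLM's actual constants or kernel code.

```python
import torch

PAD_SLOT_ID = -1  # illustrative sentinel only; vLLM defines the real constant

num_cache_slots, dim, width = 8, 4, 3
conv_state_cache = torch.zeros(num_cache_slots, dim, width - 1)

# Three live requests mapped to scattered cache slots, padded out to a batch of 4.
padded_state_indices = torch.tensor([5, 2, 7, PAD_SLOT_ID])

for batch_idx, slot in enumerate(padded_state_indices.tolist()):
    if slot == PAD_SLOT_ID:
        continue  # padded entry: nothing to read or write
    state = conv_state_cache[slot]   # gather this request's rolling conv window
    # ... the real kernel would update `state` in place with the new token ...
    conv_state_cache[slot] = state   # scatter the updated state back
```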
@@ -407,7 +405,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
         if padded_state_indices[i] == PAD_SLOT_ID:
             continue
         out_ref_b.append(
-            causal_conv1d_native(
+            causal_conv1d_ref(
                 x_s,
                 weight_ref,
                 bias_ref,
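Throughout this file the commit switches the expected values back to the `*_ref` reference implementations, which appear to be plain PyTorch versions of the causal depthwise convolution that the optimized kernels are checked against. The sketch below shows the general shape of such a reference, assuming inputs of shape (batch, dim, seqlen) and a (dim, width) weight; the function name and argument handling are illustrative, not the exact code in the test file.

```python
import torch
import torch.nn.functional as F


def causal_conv1d_ref_sketch(x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: torch.Tensor = None,
                             activation: str = None) -> torch.Tensor:
    """Naive causal depthwise conv1d: x (batch, dim, seqlen), weight (dim, width)."""
    dim, width = weight.shape
    # Depthwise conv (one filter per channel), left-padded so that the output at
    # position t only depends on inputs at positions <= t.
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    out = out[..., :x.shape[-1]]  # drop the trailing positions added by padding
    if activation == "silu":
        out = F.silu(out)
    return out
```

The `torch.allclose` assertions above then only require the optimized kernel to match such a reference up to the given tolerances.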

vllm/model_executor/layers/mamba/mamba_mixer.py (4 additions, 3 deletions)
@@ -7,7 +7,8 @@
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
The second hunk is a whitespace-only re-indentation of the MergedColumnParallelLinear arguments:

@@ -59,8 +60,8 @@ def __init__(self,
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

         self.in_proj = MergedColumnParallelLinear(hidden_size,
-                                                  [intermediate_size] * 2,
-                                                  bias=use_bias)
+                                                  [intermediate_size] * 2,
+                                                  bias=use_bias)
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
             intermediate_size,
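For readers unfamiliar with the layer being re-indented here: in the Mamba mixer, `in_proj` fuses the block's two input projections (the SSM input and the gate) into a single matmul, which is why it is given `[intermediate_size] * 2` as its output sizes. Below is a minimal single-GPU sketch of the same idea using a plain `nn.Linear` in place of vLLM's tensor-parallel `MergedColumnParallelLinear`; the sizes are made up for illustration.

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 1024, 2048  # illustrative sizes

# One fused projection whose output covers both halves,
# equivalent to two separate hidden_size -> intermediate_size linears.
in_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)

hidden_states = torch.randn(4, hidden_size)
projected = in_proj(hidden_states)            # (4, 2 * intermediate_size)
x, gate = projected.chunk(2, dim=-1)          # each (4, intermediate_size)
```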

vllm/model_executor/layers/mamba/ops/causal_conv1d.py (0 additions, 3 deletions)
@@ -4,7 +4,6 @@
 from typing import Optional

 import torch
-import torch.nn.functional as F

 from vllm import _custom_ops as ops
 from vllm.attention.backends.utils import PAD_SLOT_ID
@@ -101,5 +100,3 @@ def causal_conv1d_update(x: torch.Tensor,
     if unsqueeze:
         x = x.squeeze(-1)
     return x
-
-
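The `unsqueeze`/`squeeze` pair visible at the end of `causal_conv1d_update` is a wrapper pattern for single-token decode: a (batch, dim) input is temporarily given a length-1 sequence axis so it can be handled like a (batch, dim, seqlen) input, and the axis is dropped again before returning. A minimal sketch of that pattern under those shape assumptions (the helper name and the placeholder body are not vLLM's actual code):

```python
import torch


def update_wrapper_sketch(x: torch.Tensor) -> torch.Tensor:
    """Accepts (batch, dim) or (batch, dim, seqlen) and returns the same rank."""
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)    # (batch, dim) -> (batch, dim, 1)

    out = x  # placeholder for the real single-step conv update kernel call

    if unsqueeze:
        out = out.squeeze(-1)  # back to (batch, dim)
    return out
```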