[BugFix][Model] Jamba - Handle aborted requests, Add tests and fix cleanup bug #6425

Merged
Changes from 18 commits
83 changes: 83 additions & 0 deletions tests/models/test_jamba.py
@@ -1,5 +1,6 @@
import pytest

from tests.models.utils import check_outputs_equal
from vllm.worker.model_runner import _get_graph_batch_size

MODELS = ["ai21labs/Jamba-tiny-random"]
@@ -34,6 +35,34 @@ def test_models(
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_batching(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])

batched_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)

check_outputs_equal(
outputs_0_lst=for_loop_outputs,
outputs_1_lst=batched_outputs,
name_0="for_loop_vllm",
name_1="batched_vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
@@ -60,6 +89,60 @@ def test_mamba_cache_cg_padding(
"Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models_preemption_recompute(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
    # Tests that outputs are identical with and w/o preemptions (recompute)
assert dtype == "float"

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)

vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = False
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=preempt_vllm_outputs,
outputs_1_lst=vllm_outputs,
name_0="vllm_preepmtions",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
    # This test verifies that the Jamba inner state management doesn't
    # collapse when the number of incoming requests and finished_requests_ids
    # is larger than the maximum mamba block capacity. This can generally
    # happen because Jamba supports a statelessness mechanism where it can
    # clean up new incoming requests in a single step.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
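For readers following the new tests: they all funnel through `check_outputs_equal` from `tests/models/utils.py`, comparing the greedy outputs of two runs prompt by prompt. Below is a minimal sketch of the comparison that helper is assumed to perform (the real implementation may format its failure message differently); each output entry is a `(token_ids, text)` pair as returned by `generate_greedy`.

```python
from typing import List, Tuple


def check_outputs_equal_sketch(
    outputs_0_lst: List[Tuple[List[int], str]],
    outputs_1_lst: List[Tuple[List[int], str]],
    name_0: str,
    name_1: str,
) -> None:
    # Greedy decoding is deterministic, so both the token ids and the decoded
    # text of every prompt must match exactly between the two runs.
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for i, (output_0, output_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        ids_0, text_0 = output_0
        ids_1, text_1 = output_1
        assert ids_0 == ids_1 and text_0 == text_1, (
            f"Test{i}:\n{name_0}: {text_0!r}\n{name_1}: {text_1!r}")
```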
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -374,6 +374,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
for aborted_group in aborted_groups:
# Remove the sequence group from the state queue.
state_queue.remove(aborted_group)
self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs():
if seq.is_finished():
continue
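The single added line above is the core of the abort fix: an aborted sequence group is now reported through `_finished_requests_ids`, the same channel that normally finished requests use, so Jamba can release the request's mamba cache slot on the next step. A toy sketch of the intended flow follows; the drain method name is illustrative, not the real scheduler API.

```python
from typing import List


class ToySeqGroup:

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id


class ToyScheduler:
    """Minimal model of how finished/aborted request ids reach the model."""

    def __init__(self) -> None:
        self._finished_requests_ids: List[str] = []

    def finish_seq_group(self, group: ToySeqGroup) -> None:
        # Normal completion path: record the id so per-request state
        # (e.g. Jamba's mamba cache entry) can be released next step.
        self._finished_requests_ids.append(group.request_id)

    def abort_seq_group(self, group: ToySeqGroup) -> None:
        # The fix in this PR: aborted requests must be reported too,
        # otherwise their cache slots leak until capacity runs out.
        self._finished_requests_ids.append(group.request_id)

    def drain_finished_requests_ids(self) -> List[str]:
        # Illustrative drain step: the model runner forwards this list to
        # the model as `finished_requests_ids`, then starts fresh.
        ids, self._finished_requests_ids = self._finished_requests_ids, []
        return ids


scheduler = ToyScheduler()
scheduler.abort_seq_group(ToySeqGroup("req-1"))
assert scheduler.drain_finished_requests_ids() == ["req-1"]
assert scheduler.drain_finished_requests_ids() == []
```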
32 changes: 20 additions & 12 deletions vllm/model_executor/model_loader/loader.py
@@ -32,7 +32,8 @@
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (supports_lora,
from vllm.model_executor.models.interfaces import (has_inner_state,
supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -66,10 +67,10 @@ def _get_quantization_config(


def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
) -> Dict[str, Any]:
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}

@@ -90,13 +91,19 @@ def _get_model_initialization_kwargs(

extra_kwargs["multimodal_config"] = multimodal_config

if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config

return extra_kwargs


def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module:
def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config)
@@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, multimodal_config))
model_class, lora_config, multimodal_config,
scheduler_config))


class BaseModelLoader(ABC):
@@ -266,7 +274,7 @@ def load_model(self, *, model_config: ModelConfig,
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
cache_config, scheduler_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
@@ -306,7 +314,7 @@ def load_model(self, *, model_config: ModelConfig,
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
cache_config, scheduler_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
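The loader changes above do one thing: when the model class advertises inner state, `scheduler_config` is added to its constructor kwargs; otherwise the signature is unchanged. A condensed, self-contained sketch of that gating, using stand-in classes rather than the real vLLM types:

```python
from typing import Any, Dict, Optional, Type


class FakeSchedulerConfig:

    def __init__(self, max_num_seqs: int) -> None:
        self.max_num_seqs = max_num_seqs


class StatelessModel:
    pass


class StatefulModel:
    # Mirrors the `has_inner_state` flag from interfaces.py below.
    has_inner_state = True


def build_extra_kwargs(
        model_class: Type[object],
        scheduler_config: Optional[FakeSchedulerConfig] = None
) -> Dict[str, Any]:
    extra_kwargs: Dict[str, Any] = {}
    # Only models that declare inner state receive the scheduler config;
    # they need it (e.g. max_num_seqs) to size their per-request caches.
    if getattr(model_class, "has_inner_state", False) and scheduler_config:
        extra_kwargs["scheduler_config"] = scheduler_config
    return extra_kwargs


cfg = FakeSchedulerConfig(max_num_seqs=10)
assert "scheduler_config" not in build_extra_kwargs(StatelessModel, cfg)
assert "scheduler_config" in build_extra_kwargs(StatefulModel, cfg)
```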
48 changes: 47 additions & 1 deletion vllm/model_executor/models/interfaces.py
@@ -3,7 +3,7 @@

from typing_extensions import TypeGuard

from vllm.config import LoRAConfig, MultiModalConfig
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -142,3 +142,49 @@ def _supports_lora(
return isinstance(model, _SupportsLoRAType)

return isinstance(model, SupportsLoRA)


@runtime_checkable
class HasInnerState(Protocol):
"""The interface required for all models that has inner state."""

has_inner_state: ClassVar[Literal[True]] = True
"""
A flag that indicates this model has inner state.
    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. (Currently only used by Jamba.)
"""

def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
...


@runtime_checkable
class _HasInnerStateType(Protocol):
has_inner_state: ClassVar[Literal[True]]

def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
...


@overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
...


@overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
...


def has_inner_state(
model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
if isinstance(model, type):
return isinstance(model, _HasInnerStateType)

return isinstance(model, HasInnerState)
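With the new interface in place, a model opts in by inheriting `HasInnerState` (or simply defining the `has_inner_state` class flag) and accepting `scheduler_config` in its constructor; the `has_inner_state` helper then recognizes both the class and its instances. A short usage sketch, assuming the import paths shown in this diff:

```python
from typing import ClassVar, Literal, Optional

from vllm.config import SchedulerConfig
from vllm.model_executor.models.interfaces import (HasInnerState,
                                                   has_inner_state)


class ToyStatefulModel(HasInnerState):
    """Minimal class opting into the inner-state interface."""

    has_inner_state: ClassVar[Literal[True]] = True

    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        self.scheduler_config = scheduler_config


class ToyStatelessModel:
    pass


# The runtime-checkable protocol works on both classes and instances.
assert has_inner_state(ToyStatefulModel)
assert has_inner_state(ToyStatefulModel(scheduler_config=None))
assert not has_inner_state(ToyStatelessModel)
```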
36 changes: 25 additions & 11 deletions vllm/model_executor/models/jamba.py
@@ -13,7 +13,7 @@

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
@@ -32,10 +32,12 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import HasInnerState
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -612,7 +614,7 @@ def forward(
return hidden_states


class JambaForCausalLM(nn.Module):
class JambaForCausalLM(nn.Module, HasInnerState):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -640,9 +642,11 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None:
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.model = JambaModel(config,
cache_config=cache_config,
quant_config=quant_config,
@@ -689,16 +693,17 @@ def forward(self,
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])

request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids)
(
current_seqlen_agnostic_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size)
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
batch_size,
finished_requests_ids)
else:
# CUDA graph capturing runs
current_seqlen_agnostic_cache, indices = (
@@ -760,10 +765,15 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
return indices_for_current_run

def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
finished_requests_ids: List[str]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
indices_for_current_run = []
for request_id, seqs_id in request_ids_to_seq_ids.items():
if request_id in finished_requests_ids:
                # Do not allocate cache for requests that have already
                # finished (they arrive and finish within the same step)
continue
indices_for_current_run += self._assign_seq_id_to_mamba_cache(
request_id, seqs_id)
## Pad the batch in case of running batch that was not captured via CG
@@ -787,16 +797,17 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
(
current_mamba_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size)
cg_batch_size,
finished_requests_ids)
self.current_indices = indices
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)

for input_buffer, current_cache_buffer in zip(
input_buffers["seqlen_agnostic_capture_inputs"],
@@ -860,9 +871,12 @@ def _prepare_mamba_cache(self):
layers_type = self.config.layers_block_type
mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE)) + 10
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None

for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
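Finally, the `_prepare_mamba_cache` hunk changes how many cache slots are allocated: with a scheduler config available, the cache is sized to the CUDA-graph-padded `max_num_seqs` rather than to the largest capturable batch size, keeping the old `+ 10` headroom. A rough arithmetic sketch is below; the capture-size constants and padding rule are stated from memory and should be treated as assumptions rather than the exact vLLM values.

```python
from typing import List, Optional

# Assumed constants mirroring vLLM's CUDA graph capture sizes of this era.
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE: List[int] = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


def graph_batch_size(batch_size: int) -> int:
    # Simplified stand-in for _get_graph_batch_size: pad up to the nearest
    # capturable size (1, 2, 4, then multiples of 8).
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT


def mamba_cache_slots(max_num_seqs: Optional[int]) -> int:
    # New sizing: pad max_num_seqs to a graph batch size, plus the same
    # +10 headroom the previous code used. Without a scheduler config the
    # old behaviour (largest capturable batch + 10) is kept.
    if max_num_seqs is not None:
        return graph_batch_size(max_num_seqs) + 10
    return max(_BATCH_SIZES_TO_CAPTURE) + 10


# e.g. max_num_seqs=10 pads to 16, so 26 slots instead of 266 under the
# assumed constants above.
assert mamba_cache_slots(10) == 26
assert mamba_cache_slots(None) == max(_BATCH_SIZES_TO_CAPTURE) + 10
```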