Skip to content

Commit

Permalink
[BugFix][Model] Jamba - Handle aborted requests, Add tests and fix cl…
Browse files Browse the repository at this point in the history
…eanup bug (vllm-project#6425)

Co-authored-by: Mor Zusman <[email protected]>
  • Loading branch information
2 people authored and jimpang committed Jul 24, 2024
1 parent 76bbb5d commit 9feb0a4
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 24 deletions.
83 changes: 83 additions & 0 deletions tests/models/test_jamba.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from tests.models.utils import check_outputs_equal
from vllm.worker.model_runner import _get_graph_batch_size

MODELS = ["ai21labs/Jamba-tiny-random"]
Expand Down Expand Up @@ -34,6 +35,34 @@ def test_models(
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_batching(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])

batched_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)

check_outputs_equal(
outputs_0_lst=for_loop_outputs,
outputs_1_lst=batched_outputs,
name_0="for_loop_vllm",
name_1="batched_vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
Expand All @@ -60,6 +89,60 @@ def test_mamba_cache_cg_padding(
"Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# Tests that outputs are identical with and w/o preemtions (recompute)
assert dtype == "float"

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)

vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = False
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=preempt_vllm_outputs,
outputs_1_lst=vllm_outputs,
name_0="vllm_preepmtions",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Jamba inner state management doesn't
# collapse in case where the number of incoming requests and
# finished_requests_ids is larger than the maximum mamba block capacity.
# This could generally happen due to the fact that Jamba does support
# statelessness mechanism where it can cleanup new incoming requests in
# a single step.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
Expand Down
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
for aborted_group in aborted_groups:
# Remove the sequence group from the state queue.
state_queue.remove(aborted_group)
self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs():
if seq.is_finished():
continue
Expand Down
32 changes: 20 additions & 12 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (supports_lora,
from vllm.model_executor.models.interfaces import (has_inner_state,
supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
Expand Down Expand Up @@ -66,10 +67,10 @@ def _get_quantization_config(


def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
) -> Dict[str, Any]:
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}

Expand All @@ -90,13 +91,19 @@ def _get_model_initialization_kwargs(

extra_kwargs["multimodal_config"] = multimodal_config

if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config

return extra_kwargs


def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module:
def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config)
Expand All @@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, multimodal_config))
model_class, lora_config, multimodal_config,
scheduler_config))


class BaseModelLoader(ABC):
Expand Down Expand Up @@ -266,7 +274,7 @@ def load_model(self, *, model_config: ModelConfig,
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
cache_config, scheduler_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
Expand Down Expand Up @@ -302,7 +310,7 @@ def load_model(self, *, model_config: ModelConfig,
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
cache_config, scheduler_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
Expand Down
48 changes: 47 additions & 1 deletion vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing_extensions import TypeGuard

from vllm.config import LoRAConfig, MultiModalConfig
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand Down Expand Up @@ -142,3 +142,49 @@ def _supports_lora(
return isinstance(model, _SupportsLoRAType)

return isinstance(model, SupportsLoRA)


@runtime_checkable
class HasInnerState(Protocol):
"""The interface required for all models that has inner state."""

has_inner_state: ClassVar[Literal[True]] = True
"""
A flag that indicates this model has inner state.
Models that has inner state usually need access to the scheduler_config
for max_num_seqs ,etc... (Currently only used by Jamba)
"""

def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
...


@runtime_checkable
class _HasInnerStateType(Protocol):
has_inner_state: ClassVar[Literal[True]]

def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
...


@overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
...


@overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
...


def has_inner_state(
model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
if isinstance(model, type):
return isinstance(model, _HasInnerStateType)

return isinstance(model, HasInnerState)
36 changes: 25 additions & 11 deletions vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
Expand All @@ -32,10 +32,12 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import HasInnerState
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)

KVCache = Tuple[torch.Tensor, torch.Tensor]

Expand Down Expand Up @@ -612,7 +614,7 @@ def forward(
return hidden_states


class JambaForCausalLM(nn.Module):
class JambaForCausalLM(nn.Module, HasInnerState):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand Down Expand Up @@ -640,9 +642,11 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None:
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.model = JambaModel(config,
cache_config=cache_config,
quant_config=quant_config,
Expand Down Expand Up @@ -689,16 +693,17 @@ def forward(self,
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])

request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids)
(
current_seqlen_agnostic_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size)
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
batch_size,
finished_requests_ids)
else:
# CUDA graph capturing runs
current_seqlen_agnostic_cache, indices = (
Expand Down Expand Up @@ -760,10 +765,15 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
return indices_for_current_run

def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
finished_requests_ids: List[str]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
indices_for_current_run = []
for request_id, seqs_id in request_ids_to_seq_ids.items():
if request_id in finished_requests_ids:
# Do not allocate cache for requests that run
# and finish right after
continue
indices_for_current_run += self._assign_seq_id_to_mamba_cache(
request_id, seqs_id)
## Pad the batch in case of running batch that was not captured via CG
Expand All @@ -787,16 +797,17 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
(
current_mamba_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size)
cg_batch_size,
finished_requests_ids)
self.current_indices = indices
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)

for input_buffer, current_cache_buffer in zip(
input_buffers["seqlen_agnostic_capture_inputs"],
Expand Down Expand Up @@ -860,9 +871,12 @@ def _prepare_mamba_cache(self):
layers_type = self.config.layers_block_type
mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE)) + 10
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None

for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
Expand Down

0 comments on commit 9feb0a4

Please sign in to comment.