[BugFix][Model] Jamba - Handle aborted requests, Add tests and fix cleanup bug #6425

Merged
202 changes: 196 additions & 6 deletions tests/models/test_jamba.py
@@ -1,5 +1,6 @@
import pytest

from tests.models.utils import check_outputs_equal
from vllm.worker.model_runner import _get_graph_batch_size

MODELS = ["ai21labs/Jamba-tiny-random"]
@@ -21,17 +22,152 @@ def test_models(

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
hf_logprobs_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs=2)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
vllm_logprobs_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs=2)

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
_, hf_output_str = hf_outputs[i]
hf_output_ids, _, hf_output_logprobs = hf_logprobs_outputs[i]

_, vllm_output_str = vllm_outputs[i]
vllm_output_ids, _, vllm_output_logprobs = vllm_logprobs_outputs[i]

if hf_output_str != vllm_output_str:
first_diff_index = [
hf_id == vllm_id
for hf_id, vllm_id in zip(hf_output_ids, vllm_output_ids)
].index(False)
hf_disagreement_logprobs = hf_output_logprobs[first_diff_index]
vllm_disagreement_logprobs = {
k: v.logprob
for k, v in vllm_output_logprobs[first_diff_index].items()
}

assert (hf_output_ids[first_diff_index]
in vllm_disagreement_logprobs), (
f"Test{i}:different outputs\n"
f"HF: {hf_output_str!r}\n"
f"vLLM: {vllm_output_str!r}\n",
f"Disagreement in {first_diff_index}th token. "
f"HF id: {hf_output_ids[first_diff_index]}, "
f"vLLM id: {vllm_output_ids[first_diff_index]})\n",
"HF top token not in vLLM top 2 tokens")

vllm_disagreement_logprobs_values = list(
vllm_disagreement_logprobs.values())
vllm_logprobs_diff = abs(vllm_disagreement_logprobs_values[0] -
vllm_disagreement_logprobs_values[1])
vllm_hf_diff = abs(
hf_disagreement_logprobs[hf_output_ids[first_diff_index]] -
vllm_disagreement_logprobs[hf_output_ids[first_diff_index]])

assert (vllm_logprobs_diff < vllm_hf_diff
or vllm_logprobs_diff < 1e-4), (
f"Test{i}:different outputs\n"
f"HF: {hf_output_str!r}\n"
f"vLLM: {vllm_output_str!r}\n",
f"Disagreement in {first_diff_index}th token. "
f"HF id: {hf_output_ids[first_diff_index]}, "
f"vLLM id: {vllm_output_ids[first_diff_index]})\n",
f"HF top token in vLLM top 2 tokens, "
f"but logprobs diff is too large. "
f"vLLM top 2 logprob diff: {vllm_logprobs_diff}\n",
f"HF to vLLM diff of top HF token: {vllm_hf_diff}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [15])
def test_batching(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
# assert dtype == "float"

with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
for_loop_outputs = []
for_loop_logprobs_outputs = []
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])
for_loop_logprobs_outputs.append(
vllm_model.generate_greedy_logprobs([prompt],
max_tokens,
num_logprobs=2)[0])

batched_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
batched_logprobs_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs=2)

for i in range(len(example_prompts)):
_, for_loop_output_str = for_loop_outputs[i]
(for_loop_output_ids, _,
for_loop_output_logprobs) = for_loop_logprobs_outputs[i]

_, batched_output_str = batched_outputs[i]
(batched_output_ids, _,
batched_output_logprobs) = batched_logprobs_outputs[i]

if for_loop_output_str != batched_output_str:
first_diff_index = [
for_loop_id == batched_id for for_loop_id, batched_id in zip(
for_loop_output_ids, batched_output_ids)
].index(False)
for_loop_disagreement_logprobs = {
k: v.logprob
for k, v in for_loop_output_logprobs[first_diff_index].items()
}
batched_disagreement_logprobs = {
k: v.logprob
for k, v in batched_output_logprobs[first_diff_index].items()
}

assert (
for_loop_output_ids[first_diff_index]
in batched_disagreement_logprobs), (
f"Test{i}:different outputs\n"
f"For-loop: {for_loop_output_str!r}\n",
f"Batched: {batched_output_str!r}\n",
f"Disagreement in {first_diff_index}th token. "
f"For-loop id: {for_loop_output_ids[first_diff_index]}, "
f"Batched id: {batched_output_ids[first_diff_index]})\n",
"For-loop top token not in batched top 2 tokens")

batched_disagreement_logprobs_values = list(
batched_disagreement_logprobs.values())
batched_logprobs_diff = abs(
batched_disagreement_logprobs_values[0] -
batched_disagreement_logprobs_values[1])
batched_for_loop_diff = abs(
for_loop_disagreement_logprobs[
for_loop_output_ids[first_diff_index]] -
batched_disagreement_logprobs[
for_loop_output_ids[first_diff_index]])

assert (
batched_logprobs_diff < batched_for_loop_diff
or batched_logprobs_diff < 1e-4), (
f"Test{i}:different outputs\n"
f"For-loop: {for_loop_output_str!r}\n"
f"Batched: {batched_output_str!r}\n",
f"Disagreement in {first_diff_index}th token. "
f"For-loop id: {for_loop_output_ids[first_diff_index]}, "
f"Batched id: {batched_output_ids[first_diff_index]})\n",
f"For-loop top token in batched top 2 tokens, "
f"but logprobs diff is too large. "
f"Batched top 2 logprob diff: {batched_logprobs_diff}\n",
f"For-loop to batched diff of top for-loop token: "
f"{batched_logprobs_diff}")


@pytest.mark.parametrize("model", MODELS)
@@ -60,6 +196,60 @@ def test_mamba_cache_cg_padding(
"Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models_preemption_recompute(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
    # Tests that outputs are identical with and w/o preemptions (recompute)
assert dtype == "float"

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)

vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = False
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=preempt_vllm_outputs,
outputs_1_lst=vllm_outputs,
name_0="vllm_preepmtions",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
    # This test verifies that the Jamba inner state management doesn't
    # collapse when the number of incoming requests and finished_requests_ids
    # is larger than the maximum mamba block capacity.
    # This could generally happen because Jamba supports a statelessness
    # mechanism that lets it clean up new incoming requests within a single
    # step.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -374,6 +374,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
for aborted_group in aborted_groups:
# Remove the sequence group from the state queue.
state_queue.remove(aborted_group)
self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs():
if seq.is_finished():
continue
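
The one-line scheduler change above appends aborted requests to _finished_requests_ids, so that a model holding inner state (such as Jamba's mamba cache) can release the per-request entries those requests occupied. The following sketch illustrates the consuming side of such a list; the class and method names are hypothetical and not vLLM APIs, it only shows the cleanup pattern the PR relies on.

import torch


class PerRequestStateCache:
    """Toy illustration: per-request state slots released on finish/abort."""

    def __init__(self, max_num_seqs: int, state_shape: tuple) -> None:
        # One state buffer slot per possible concurrent sequence.
        self.buffers = torch.zeros(max_num_seqs, *state_shape)
        self.request_to_slot: dict = {}
        self.free_slots = list(range(max_num_seqs))

    def release_finished(self, finished_requests_ids: list) -> None:
        # Called once per step with the ids collected by the scheduler,
        # which now include aborted requests as well as completed ones.
        for request_id in finished_requests_ids:
            slot = self.request_to_slot.pop(request_id, None)
            if slot is not None:
                self.buffers[slot].zero_()
                self.free_slots.append(slot)
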
32 changes: 20 additions & 12 deletions vllm/model_executor/model_loader/loader.py
@@ -32,7 +32,8 @@
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (supports_lora,
from vllm.model_executor.models.interfaces import (has_inner_state,
supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -66,10 +67,10 @@ def _get_quantization_config(


def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
) -> Dict[str, Any]:
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}

@@ -90,13 +91,19 @@ def _get_model_initialization_kwargs(

extra_kwargs["multimodal_config"] = multimodal_config

if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config

return extra_kwargs


def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module:
def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config)
@@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, multimodal_config))
model_class, lora_config, multimodal_config,
scheduler_config))


class BaseModelLoader(ABC):
@@ -266,7 +274,7 @@ def load_model(self, *, model_config: ModelConfig,
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
cache_config, scheduler_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
@@ -306,7 +314,7 @@ def load_model(self, *, model_config: ModelConfig,
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
cache_config, scheduler_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
35 changes: 34 additions & 1 deletion vllm/model_executor/models/interfaces.py
@@ -3,7 +3,7 @@

from typing_extensions import TypeGuard

from vllm.config import LoRAConfig, MultiModalConfig
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -142,3 +142,36 @@ def _supports_lora(
return isinstance(model, _SupportsLoRAType)

return isinstance(model, SupportsLoRA)


@runtime_checkable
class HasInnerState(Protocol):
"""The interface required for all models that has inner state."""

has_inner_state: ClassVar[Literal[True]] = True
"""
A flag that indicates this model has inner state.
    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. (Currently only used by Jamba.)
"""

def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
...


@overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
...


@overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
...


def has_inner_state(
model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
return isinstance(model, HasInnerState)
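
For illustration, here is a minimal sketch of how a model could opt in to this interface. The class name and body are hypothetical; only the has_inner_state flag and the scheduler_config keyword follow the protocol above.

from typing import ClassVar, Literal, Optional

import torch.nn as nn

from vllm.config import SchedulerConfig


class MyStatefulModel(nn.Module):
    # The loader detects this flag via has_inner_state() and forwards the
    # scheduler config as an extra keyword argument at construction time.
    has_inner_state: ClassVar[Literal[True]] = True

    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None,
                 **kwargs) -> None:
        super().__init__()
        # For example, size internal state buffers by the maximum batch size
        # the scheduler will ever submit.
        self.max_batch_size = (scheduler_config.max_num_seqs
                               if scheduler_config is not None else 1)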