Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support mistral interleaved attn #9414

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,20 @@ def __init__(self,
if self.enforce_eager is None:
self.enforce_eager = False

if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in ["gemma2"]))

if (not self.disable_sliding_window and has_interleaved_attention):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)

print_warning_once(
"Gemma 2 uses sliding window attention for every odd layer, "
f"{self.hf_text_config.model_type} has interleaved attention, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({self.hf_text_config.sliding_window}).")
f"({sliding_window_len_min}).")
self.disable_sliding_window = True

self.max_model_len = _get_and_verify_max_len(
Expand Down Expand Up @@ -422,7 +428,8 @@ def verify_with_parallel_config(
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False

def get_hf_config_sliding_window(self) -> Optional[int]:
def get_hf_config_sliding_window(
self) -> Union[Optional[int], List[Optional[int]]]:
"""Get the sliding window size, or None if disabled."""

# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
Expand All @@ -433,7 +440,7 @@ def get_hf_config_sliding_window(self) -> Optional[int]:
return None
return getattr(self.hf_text_config, "sliding_window", None)

def get_sliding_window(self) -> Optional[int]:
def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
Expand Down Expand Up @@ -1680,7 +1687,7 @@ def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[int],
sliding_window_len: Optional[Union[int, List[Optional[int]]]],
spec_target_max_model_len: Optional[int] = None,
) -> int:
"""Get and verify the model's maximum length."""
Expand Down Expand Up @@ -1713,9 +1720,12 @@ def _get_and_verify_max_len(
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if disable_sliding_window and sliding_window_len is not None:

sliding_window_len_min = get_min_sliding_window(sliding_window_len)
max_len_key = "sliding_window" \
if sliding_window_len < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len, sliding_window_len)
if sliding_window_len_min < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len,
sliding_window_len_min)

# If none of the keys were found in the config, use a default and
# log a warning.
Expand Down Expand Up @@ -1803,6 +1813,14 @@ def _get_and_verify_max_len(
return int(max_model_len)


def get_min_sliding_window(
sliding_window: Union[int, List[Optional[int]]]) -> int:
if isinstance(sliding_window, list):
return min(s for s in sliding_window if s is not None)

return sliding_window


def get_served_model_name(model: str,
served_model_name: Optional[Union[str, List[str]]]):
"""
Expand Down