[Frontend] Warn if user max_model_len is greater than derived `max_model_len` (#7080)

Signed-off-by: Jefferson Fialho <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
fialhocoelho and njhill authored Aug 3, 2024
1 parent 44dcb52 commit 825b044
Showing 2 changed files with 23 additions and 6 deletions.
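In practice, the change means an over-long `max_model_len` no longer hard-fails when the user opts in via the new environment variable. A hedged usage sketch (the model name and the 8192 value are placeholders, not from this commit; it assumes the env var is set before vLLM reads its environment):

```python
import os

# Opt in before constructing the engine; only "1" or "true" count
# (see the envs.py parsing below).
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"

from vllm import LLM

# Suppose the model's config.json derives a smaller limit; requesting 8192
# now logs a warning instead of raising ValueError. Without the env var
# set, this constructor call would fail.
llm = LLM(model="my-org/my-model", max_model_len=8192)
```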
19 changes: 13 additions & 6 deletions vllm/config.py
```diff
@@ -6,6 +6,7 @@
 import torch
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
@@ -1541,15 +1542,21 @@ def _get_and_verify_max_len(
                     "Disabling sliding window is not supported for models "
                     "model_max_length in the config. Please raise an issue "
                     "so we can investigate.")
-            pass
         else:
-            raise ValueError(
+            msg = (
                 f"User-specified max_model_len ({max_model_len}) is greater "
-                "than the derived max_model_len "
-                f"({max_len_key}={derived_max_model_len} or model_max_length="
+                f"than the derived max_model_len ({max_len_key}="
+                f"{derived_max_model_len} or model_max_length="
                 f"{model_max_length} in model's config.json). This may lead "
-                "to incorrect model outputs or CUDA errors. Make sure the "
-                "value is correct and within the model context size.")
+                "to incorrect model outputs or CUDA errors.")
+            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
+                logger.warning(
+                    "%s Make sure the value is correct and within the "
+                    "model context size.", msg)
+            else:
+                raise ValueError(
+                    f"{msg} To allow overriding this maximum, set "
+                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
     return int(max_model_len)
 
 
```
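To make the new control flow easy to follow outside the diff, here is a minimal, self-contained sketch of the gate. `check_max_model_len` and its parameters are stand-ins mirroring the names in the diff, not the actual vLLM function (`_get_and_verify_max_len` takes more arguments and derives these values itself):

```python
import logging
import os

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("sketch")


def check_max_model_len(max_model_len: int, derived_max_model_len: int,
                        max_len_key: str, model_max_length: int) -> int:
    """Sketch of the gate above: warn instead of raising when allowed."""
    if max_model_len > derived_max_model_len:
        msg = (f"User-specified max_model_len ({max_model_len}) is greater "
               f"than the derived max_model_len ({max_len_key}="
               f"{derived_max_model_len} or model_max_length="
               f"{model_max_length} in model's config.json). This may lead "
               "to incorrect model outputs or CUDA errors.")
        # Same truthy parsing as the envs.py entry shown below.
        if os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN",
                          "0").strip().lower() in ("1", "true"):
            logger.warning(
                "%s Make sure the value is correct and within the "
                "model context size.", msg)
        else:
            raise ValueError(f"{msg} To allow overriding this maximum, set "
                             "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
    return int(max_model_len)


# With the env var unset this would raise for 8192 > 4096; with
# VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 it only warns and returns 8192.
print(check_max_model_len(4096, 4096, "max_position_embeddings", 4096))
```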
10 changes: 10 additions & 0 deletions vllm/envs.py
```diff
@@ -50,6 +50,7 @@
     VLLM_NO_DEPRECATION_WARNING: bool = False
     CMAKE_BUILD_TYPE: Optional[str] = None
     VERBOSE: bool = False
+    VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
 
 
 def get_default_cache_root():
@@ -331,6 +332,15 @@ def get_default_config_root():
     # If set, vllm will skip the deprecation warnings.
     "VLLM_NO_DEPRECATION_WARNING":
     lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
+
+    # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
+    # the user to specify a max sequence length greater than
+    # the max length derived from the model's config.json.
+    # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
+    "VLLM_ALLOW_LONG_MAX_MODEL_LEN":
+    lambda:
+    (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
+     ("1", "true")),
 }
 
 # end-env-vars-definition
```
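One subtlety worth noting: unlike `VLLM_NO_DEPRECATION_WARNING` just above it, which parses via `bool(int(...))` and would crash on a value like `"true"`, the new entry accepts `"1"` and `"true"` (case-insensitive, whitespace-stripped) and treats everything else as off. A small standalone check of that parsing expression, with `_parse_allow_long` as a hypothetical helper name:

```python
def _parse_allow_long(value: str) -> bool:
    # Mirrors the lambda above: strip whitespace, lowercase, then
    # accept only the exact strings "1" and "true".
    return value.strip().lower() in ("1", "true")


for raw in ("1", "true", " TRUE ", "0", "yes", "2", ""):
    print(repr(raw), "->", _parse_allow_long(raw))
# '1' -> True
# 'true' -> True
# ' TRUE ' -> True
# '0' -> False
# 'yes' -> False
# '2' -> False
# '' -> False
```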
