Skip to content

Commit

Permalink
move functions to config.py
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Nov 25, 2024
1 parent 6581378 commit e3ff01e
Show file tree
Hide file tree
Showing 11 changed files with 62 additions and 73 deletions.
4 changes: 2 additions & 2 deletions tests/compile/piecewise/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.utils import direct_register_custom_op

global_counter = 0
Expand Down
4 changes: 2 additions & 2 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
Expand Down
3 changes: 1 addition & 2 deletions tests/kernels/test_encoder_decoder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config

# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
Expand Down
3 changes: 1 addition & 2 deletions tests/model_executor/test_enabled_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import pytest

from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config


# Registered subclass for test
Expand Down
3 changes: 1 addition & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op


Expand Down
3 changes: 1 addition & 2 deletions vllm/compilation/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch

import vllm.envs as envs
from vllm.config import CompilationLevel
from vllm.config import CompilationLevel, get_current_vllm_config


class TorchCompileWrapperWithCustomDispatcher:
Expand All @@ -32,7 +32,6 @@ def __init__(self,
# default compilation settings
# compiling the forward method

from vllm.plugins import get_current_vllm_config
backend = get_current_vllm_config(
).compilation_config.init_backend()

Expand Down
51 changes: 51 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hashlib
import json
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
Expand Down Expand Up @@ -2436,3 +2437,53 @@ def __str__(self):
self.cache_config.enable_prefix_caching,
self.model_config.use_async_output_proc,
self.model_config.mm_processor_kwargs)


_current_vllm_config: Optional[VllmConfig] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig):
"""
Temporarily set the current VLLM config.
Used during model initialization.
We save the current VLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the VLLM config to determine how to dispatch.
"""
global _current_vllm_config
old_vllm_config = _current_vllm_config
from vllm.compilation.counter import compilation_counter
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
yield
finally:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:
# If the model supports compilation,
# compilation_counter.num_models_seen should be increased
# by at least 1.
# If it is not increased, it means the model does not support
# compilation (does not have @support_torch_compile decorator).
logger.warning(
"`torch.compile` is turned on, but the model %s"
" does not support it. Please open an issue on GitHub"
"if you want it to be supported.",
vllm_config.model_config.model)
_current_vllm_config = old_vllm_config


def get_current_vllm_config() -> VllmConfig:
if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
logger.warning("Current VLLM config is not set.")
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config
2 changes: 1 addition & 1 deletion vllm/model_executor/custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch.nn as nn

from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import print_warning_once

logger = init_logger(__name__)
Expand Down
3 changes: 1 addition & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
VllmConfig)
VllmConfig, set_current_vllm_config)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
Expand All @@ -47,7 +47,6 @@
safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
from vllm.utils import is_pin_memory_available


Expand Down
3 changes: 1 addition & 2 deletions vllm/model_executor/model_loader/tensorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig
from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.plugins import set_current_vllm_config
from vllm.utils import FlexibleArgumentParser

tensorizer_error_msg = None
Expand Down
56 changes: 0 additions & 56 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional

import torch

import vllm.envs as envs

if TYPE_CHECKING:
from vllm.config import VllmConfig

logger = logging.getLogger(__name__)

# make sure one process only loads plugins once
Expand Down Expand Up @@ -64,54 +59,3 @@ def load_general_plugins():
logger.info("plugin %s loaded.", plugin.name)
except Exception:
logger.exception("Failed to load plugin %s", plugin.name)


_current_vllm_config: Optional["VllmConfig"] = None


@contextmanager
def set_current_vllm_config(vllm_config: "VllmConfig"):
"""
Temporarily set the current VLLM config.
Used during model initialization.
We save the current VLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the VLLM config to determine how to dispatch.
"""
global _current_vllm_config
old_vllm_config = _current_vllm_config
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationLevel
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
yield
finally:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:
# If the model supports compilation,
# compilation_counter.num_models_seen should be increased
# by at least 1.
# If it is not increased, it means the model does not support
# compilation (does not have @support_torch_compile decorator).
logger.warning(
"`torch.compile` is turned on, but the model %s"
" does not support it. Please open an issue on GitHub"
"if you want it to be supported.",
vllm_config.model_config.model)
_current_vllm_config = old_vllm_config


def get_current_vllm_config() -> "VllmConfig":
if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
logger.warning("Current VLLM config is not set.")
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config

0 comments on commit e3ff01e

Please sign in to comment.