Skip to content

Commit

Permalink
[ci][bugfix] fix kernel tests (vllm-project#10431)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
youkaichao authored and mfournioux committed Nov 20, 2024
1 parent 079658e commit 8cccd5e
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

if TYPE_CHECKING:
from vllm.config import CompilationConfig, VllmConfig
else:
CompilationConfig = None
VllmConfig = None

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,23 +47,23 @@ def load_general_plugins():
logger.exception("Failed to load plugin %s", plugin.name)


_compilation_config: Optional[CompilationConfig] = None
_compilation_config: Optional["CompilationConfig"] = None


def set_compilation_config(config: Optional[CompilationConfig]):
def set_compilation_config(config: Optional["CompilationConfig"]):
global _compilation_config
_compilation_config = config


def get_compilation_config() -> Optional[CompilationConfig]:
def get_compilation_config() -> Optional["CompilationConfig"]:
return _compilation_config


_current_vllm_config: Optional[VllmConfig] = None
_current_vllm_config: Optional["VllmConfig"] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig):
def set_current_vllm_config(vllm_config: "VllmConfig"):
"""
Temporarily set the current VLLM config.
Used during model initialization.
Expand All @@ -87,6 +84,12 @@ def set_current_vllm_config(vllm_config: VllmConfig):
_current_vllm_config = old_vllm_config


def get_current_vllm_config() -> VllmConfig:
assert _current_vllm_config is not None, "Current VLLM config is not set."
def get_current_vllm_config() -> "VllmConfig":
if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
logger.warning("Current VLLM config is not set.")
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config

0 comments on commit 8cccd5e

Please sign in to comment.