diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 3f02175b0439f..54ba60229325e 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -15,8 +15,6 @@ from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op -CUSTOM_OPS_REGISTERED = False - class Attention(nn.Module): """Attention layer. @@ -131,7 +129,6 @@ def forward( attn_metadata: AttentionMetadata, attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: - _register_custom_ops() if self.use_direct_call: return self.impl.forward(query, @@ -301,12 +298,8 @@ def unified_attention_with_output_fake( return -def _register_custom_ops(): +def register_custom_ops(): """Register custom ops for attention.""" - global CUSTOM_OPS_REGISTERED - if CUSTOM_OPS_REGISTERED: - return - direct_register_custom_op( op_name="unified_attention", op_func=unified_attention, @@ -321,4 +314,3 @@ def _register_custom_ops(): fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) - CUSTOM_OPS_REGISTERED = True diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 3b394b9661716..c4c74dfed9a9f 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -2,28 +2,6 @@ from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform from .registry import PlatformRegistry, detect_current_platform -_current_platform: Platform = UnspecifiedPlatform() - - -def initialize_current_platform(): - """Initialize the current platform. This function is called when loading - the vllm plugin.""" - global _current_platform - # Get the current platform from the registry first. If the current platform - # is not set, try to detect the current platform. - if PlatformRegistry.current_platform is not None: - _current_platform = PlatformRegistry.get_current_platform_cls() - else: - _current_platform = detect_current_platform() - - -def update_current_platform(device_name: str): - """Update the current platform. This function is used by users to set the - current platform by hand.""" - global _current_platform - PlatformRegistry.set_current_platform(device_name) - _current_platform = PlatformRegistry.get_current_platform_cls() - class CurrentPlatform: """A wrapper that provides an interface to the current platform. @@ -37,17 +15,38 @@ class CurrentPlatform: """ def __init__(self): - self.platform = _current_platform - - def _refresh_current_platform(self): - """Refresh the current platform dynamically.""" - global _current_platform - if _current_platform is not self.platform: - self.platform = _current_platform + self.platform = UnspecifiedPlatform() + + def initialize_current_platform(self): + """Initialize the current platform. This function is called when loading + the vllm plugin.""" + # Get the current platform from the registry first. If the current + # platform is not set, try to detect the current platform. + if PlatformRegistry.current_platform is not None: + self.platform = PlatformRegistry.get_current_platform_cls() + else: + self.platform = detect_current_platform() + + # Register custom ops for the current platform. + from vllm.attention.layer import register_custom_ops + register_custom_ops() + + def update_current_platform(self, device_name: str): + """Update the current platform. This function is used by users to set + the current platform by hand.""" + PlatformRegistry.set_current_platform(device_name) + self.platform = PlatformRegistry.get_current_platform_cls() def __getattr__(self, name): - """Go pass to the current platform.""" - self._refresh_current_platform() + """Get the attribute. If the attribute is not found, go pass to the + current platform.""" + if name is 'platform': + return self.platform + if name is 'initialize_current_platform': + return self.initialize_current_platform + if name is 'update_current_platform': + return self.update_current_platform + # Go pass to the current platform. return getattr(self.platform, name) diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 0336b69a78376..051306a3812dc 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -4,7 +4,7 @@ import torch import vllm.envs as envs -from vllm.platforms import current_platform, initialize_current_platform +from vllm.platforms import current_platform logger = logging.getLogger(__name__) @@ -52,7 +52,7 @@ def load_general_plugins(): plugin.name) # initialize current platform should be called after all plugins are # loaded. - initialize_current_platform() + current_platform.initialize_current_platform() plugins_loaded = True