[3/N] model runner pass the whole config to model #9958

Merged (6 commits, Nov 2, 2024)
Changes from all commits
9 changes: 4 additions & 5 deletions tests/lora/conftest.py
@@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model

def get_model_patched(*, model_config, device_config, **kwargs):
kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
return get_model_old(model_config=model_config,
device_config=device_config,
**kwargs)
def get_model_patched(**kwargs):
kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
max_lora_rank=8)
return get_model_old(**kwargs)

with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
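Note: the fixture above now only needs to override one attribute on the incoming config. A minimal, self-contained sketch of the same wrap-and-delegate pattern, assuming the import paths visible in this diff (`vllm.config.LoRAConfig`, `vllm.model_executor.model_loader.get_model`, and the `vllm.worker.model_runner.get_model` patch target); the fixture itself is the authoritative version:

```python
from unittest.mock import patch  # used in the commented usage example below

from vllm.config import LoRAConfig
from vllm.model_executor.model_loader import get_model as get_model_real


def get_model_with_forced_lora(**kwargs):
    # The whole VllmConfig arrives as a single keyword argument, so forcing
    # LoRA on is one attribute assignment instead of an extra kwarg.
    kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, max_lora_rank=8)
    return get_model_real(**kwargs)


# Usage inside a test: patch the name the model runner resolves at call time.
# with patch("vllm.worker.model_runner.get_model", get_model_with_forced_lora):
#     engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
```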
20 changes: 4 additions & 16 deletions vllm/model_executor/model_loader/__init__.py
@@ -1,27 +1,15 @@
from typing import Optional

from torch import nn

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader)
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture)


def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config)
def get_model(*, vllm_config: VllmConfig) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
return loader.load_model(vllm_config=vllm_config)


__all__ = [
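The public entry point now takes a single keyword-only parameter: callers that previously threaded seven separate config objects pass one `VllmConfig`, and positional calls are rejected. A tiny standalone demo of that calling convention (`FakeVllmConfig` is a hypothetical stand-in, not part of this PR):

```python
class FakeVllmConfig:
    load_config = "dummy-load-config"


def get_model(*, vllm_config: FakeVllmConfig) -> str:
    # Mirrors the shape above: pull sub-configs off the single object.
    return f"model loaded via {vllm_config.load_config}"


print(get_model(vllm_config=FakeVllmConfig()))  # ok: explicit keyword

try:
    get_model(FakeVllmConfig())  # the bare '*' rejects positional arguments
except TypeError as exc:
    print("rejected:", exc)
```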
132 changes: 54 additions & 78 deletions vllm/model_executor/model_loader/loader.py
@@ -21,9 +21,9 @@
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, PoolerConfig, SchedulerConfig)
from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PoolerConfig, SchedulerConfig, VllmConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
@@ -150,6 +150,7 @@ def _get_model_initialization_kwargs(


def build_model(model_class: Type[nn.Module],
vllm_config: VllmConfig,
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
@@ -166,23 +167,29 @@ def build_model(model_class: Type[nn.Module],
if prefix:
extra_kwargs["prefix"] = prefix

# TODO: unify all the module initialization code
# to only take the `VllmConfig` object as input
from vllm.plugins import set_vllm_config
set_vllm_config(vllm_config)

return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
**extra_kwargs)


def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
load_config = vllm_config.load_config
model_class, _ = get_model_architecture(model_config)

return build_model(
model_class,
vllm_config,
model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(model_config, load_config),
@@ -205,12 +212,7 @@ def download_model(self, model_config: ModelConfig) -> None:
raise NotImplementedError

@abstractmethod
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
"""Load a model with the given configurations."""
raise NotImplementedError

@@ -396,18 +398,14 @@ def download_model(self, model_config: ModelConfig) -> None:
model_config.revision,
fall_back_to_pt=True)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config

target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
model = _initialize_model(vllm_config=vllm_config)

model.load_weights(self._get_all_weights(model_config, model))

@@ -436,17 +434,12 @@ def __init__(self, load_config: LoadConfig):
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
model = _initialize_model(vllm_config=vllm_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
@@ -488,10 +481,7 @@ def _get_weights_iterator(

def _load_model_serialized_cpu(
self,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
vllm_config: VllmConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.

@@ -500,26 +490,30 @@ def _load_model_serialized_cpu(
default HuggingFace loading, but will be slower than loading a
vLLM-tensorized model.
"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)

model.load_weights(self._get_weights_iterator())
return model.eval()

def _load_model_serialized(
self,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
vllm_config: VllmConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer.

Expects a vLLM-tensorized model. See the
examples/tensorize_vllm_model.py example script
for serializing vLLM models."""

device_config = vllm_config.device_config
model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
cache_config = vllm_config.cache_config

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
@@ -544,12 +538,9 @@ def download_model(self, model_config: ModelConfig) -> None:
with self.tensorizer_config.open_stream():
pass

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)

if parallel_config.tensor_parallel_size > 1:
@@ -559,10 +550,8 @@ def load_model(self, *, model_config: ModelConfig,
% get_tensor_model_parallel_rank()

if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config, cache_config)
return self._load_model_serialized_cpu(model_config, device_config,
lora_config, cache_config)
return self._load_model_serialized(vllm_config=vllm_config)
return self._load_model_serialized_cpu(vllm_config=vllm_config)

@staticmethod
def save_model(
@@ -648,12 +637,9 @@ def _prepare_weights(self, model_name_or_path: str,
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
from safetensors.torch import safe_open

from vllm.distributed import get_tensor_model_parallel_rank
@@ -663,8 +649,7 @@ def load_model(self, *, model_config: ModelConfig,

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
@@ -1157,16 +1142,12 @@ def _load_weights(self, model_config: ModelConfig,
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)

self._load_weights(model_config, model)

@@ -1235,13 +1216,9 @@ def _get_weights_iterator(
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:

def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
@@ -1251,8 +1228,7 @@ def load_model(self, *, model_config: ModelConfig,

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map))
return model
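Every loader above follows the same recipe after this change: accept the whole `VllmConfig`, unpack only the sub-configs it needs, and hand the full object on to `_initialize_model`. A hedged sketch of a custom loader written against the new interface (the class name and its skip-the-weights behavior are illustrative, not part of this PR):

```python
import torch
from torch import nn

from vllm.config import VllmConfig
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
                                                     _initialize_model)


class NoWeightsLoader(BaseModelLoader):
    """Illustrative loader: builds the model and skips loading weights."""

    def download_model(self, model_config) -> None:
        pass  # nothing to download in this sketch

    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        # Unpack only what this loader needs from the composed config.
        device_config = vllm_config.device_config
        with torch.device(device_config.device):
            # The rest of the configuration travels inside vllm_config.
            model = _initialize_model(vllm_config=vllm_config)
        return model.eval()


# Usage sketch (assuming BaseModelLoader's constructor takes the LoadConfig):
# loader = NoWeightsLoader(vllm_config.load_config)
# model = loader.load_model(vllm_config=vllm_config)
```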
22 changes: 20 additions & 2 deletions vllm/plugins/__init__.py
@@ -1,8 +1,14 @@
import logging
from typing import Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Union

import vllm.envs as envs
from vllm.compilation.config import CompilationConfig

if TYPE_CHECKING:
from vllm.compilation.config import CompilationConfig
from vllm.config import VllmConfig
else:
CompilationConfig = None
VllmConfig = None

logger = logging.getLogger(__name__)

@@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]):

def get_compilation_config() -> Optional[CompilationConfig]:
return _compilation_config


_vllm_config: Optional[VllmConfig] = None


def set_vllm_config(config: Optional[VllmConfig]):
global _vllm_config
_vllm_config = config


def get_vllm_config() -> Optional[VllmConfig]:
return _vllm_config
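The new `set_vllm_config`/`get_vllm_config` pair is the interim mechanism behind the TODO in `build_model`: the full config is stashed in a module-level global just before the model class is instantiated, so layers that are not yet passed a `VllmConfig` explicitly can still read it during construction. A hedged sketch of that access pattern (`ExampleBlock` and its fallback value of 16 are hypothetical):

```python
from torch import nn

from vllm.plugins import get_vllm_config


class ExampleBlock(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        # build_model() calls set_vllm_config() right before instantiating the
        # model, so the global is populated here; outside that window it is None.
        vllm_config = get_vllm_config()
        block_size = 16  # hypothetical fallback when no config is set
        if vllm_config is not None and vllm_config.cache_config is not None:
            block_size = vllm_config.cache_config.block_size
        self.block_size = block_size
        self.proj = nn.Linear(hidden_size, hidden_size)
```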
8 changes: 1 addition & 7 deletions vllm/v1/worker/gpu_model_runner.py
@@ -369,13 +369,7 @@ def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)

self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
8 changes: 1 addition & 7 deletions vllm/worker/cpu_model_runner.py
@@ -453,13 +453,7 @@ def model_is_mrope(self) -> bool:
return uses_mrope(self.model_config.hf_config)

def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)

def make_model_input_from_broadcasted_tensor_dict(
self,
8 changes: 1 addition & 7 deletions vllm/worker/model_runner.py
@@ -1051,13 +1051,7 @@ def __init__(
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)

self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",