[3/N] model runner pass the whole config to model (vllm-project#9958)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
youkaichao authored and sumitd2 committed Nov 14, 2024
1 parent 712f3b7 commit eee723d
Showing 9 changed files with 87 additions and 140 deletions.
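At a glance, this change collapses the per-field keyword arguments (model_config, device_config, load_config, lora_config, parallel_config, scheduler_config, cache_config) into a single VllmConfig object passed end to end. A minimal sketch of the new calling convention, assuming a VllmConfig already built by the engine (the load_for_runner helper is hypothetical, not part of this commit):

```python
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model


def load_for_runner(vllm_config: VllmConfig):
    # Previously each sub-config was forwarded as its own keyword argument;
    # after this commit the whole config object is handed to the loader.
    return get_model(vllm_config=vllm_config)
```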
9 changes: 4 additions & 5 deletions tests/lora/conftest.py
@@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model

def get_model_patched(*, model_config, device_config, **kwargs):
kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
return get_model_old(model_config=model_config,
device_config=device_config,
**kwargs)
def get_model_patched(**kwargs):
kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
max_lora_rank=8)
return get_model_old(**kwargs)

with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
20 changes: 4 additions & 16 deletions vllm/model_executor/model_loader/__init__.py
@@ -1,27 +1,15 @@
from typing import Optional

from torch import nn

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader)
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture)


def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config)
def get_model(*, vllm_config: VllmConfig) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
return loader.load_model(vllm_config=vllm_config)


__all__ = [
132 changes: 54 additions & 78 deletions vllm/model_executor/model_loader/loader.py
@@ -21,9 +21,9 @@
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, PoolerConfig, SchedulerConfig)
from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PoolerConfig, SchedulerConfig, VllmConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
@@ -150,6 +150,7 @@ def _get_model_initialization_kwargs(


def build_model(model_class: Type[nn.Module],
vllm_config: VllmConfig,
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
@@ -166,23 +167,29 @@ def build_model(model_class: Type[nn.Module],
if prefix:
extra_kwargs["prefix"] = prefix

# TODO: unify all the module initialization code
# to only take the `VllmConfig` object as input
from vllm.plugins import set_vllm_config
set_vllm_config(vllm_config)

return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
**extra_kwargs)


def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
load_config = vllm_config.load_config
model_class, _ = get_model_architecture(model_config)

return build_model(
model_class,
vllm_config,
model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(model_config, load_config),
@@ -205,12 +212,7 @@ def download_model(self, model_config: ModelConfig) -> None:
raise NotImplementedError

@abstractmethod
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
"""Load a model with the given configurations."""
raise NotImplementedError

@@ -396,18 +398,14 @@ def download_model(self, model_config: ModelConfig) -> None:
model_config.revision,
fall_back_to_pt=True)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config

target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
model = _initialize_model(vllm_config=vllm_config)

model.load_weights(self._get_all_weights(model_config, model))

@@ -436,17 +434,12 @@ def __init__(self, load_config: LoadConfig):
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
model = _initialize_model(vllm_config=vllm_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
@@ -488,10 +481,7 @@ def _get_weights_iterator(

def _load_model_serialized_cpu(
self,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
vllm_config: VllmConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.
@@ -500,26 +490,30 @@ def _load_model_serialized_cpu(
default HuggingFace loading, but will be slower than loading a
vLLM-tensorized model.
"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)

model.load_weights(self._get_weights_iterator())
return model.eval()

def _load_model_serialized(
self,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
vllm_config: VllmConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer.
Expects a vLLM-tensorized model. See the
examples/tensorize_vllm_model.py example script
for serializing vLLM models."""

device_config = vllm_config.device_config
model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
cache_config = vllm_config.cache_config

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
@@ -544,12 +538,9 @@ def download_model(self, model_config: ModelConfig) -> None:
with self.tensorizer_config.open_stream():
pass

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)

if parallel_config.tensor_parallel_size > 1:
Expand All @@ -559,10 +550,8 @@ def load_model(self, *, model_config: ModelConfig,
% get_tensor_model_parallel_rank()

if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config, cache_config)
return self._load_model_serialized_cpu(model_config, device_config,
lora_config, cache_config)
return self._load_model_serialized(vllm_config=vllm_config)
return self._load_model_serialized_cpu(vllm_config=vllm_config)

@staticmethod
def save_model(
@@ -648,12 +637,9 @@ def _prepare_weights(self, model_name_or_path: str,
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
from safetensors.torch import safe_open

from vllm.distributed import get_tensor_model_parallel_rank
Expand All @@ -663,8 +649,7 @@ def load_model(self, *, model_config: ModelConfig,

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
@@ -1157,16 +1142,12 @@ def _load_weights(self, model_config: ModelConfig,
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)

self._load_weights(model_config, model)

@@ -1235,13 +1216,9 @@ def _get_weights_iterator(
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:

def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
Expand All @@ -1251,8 +1228,7 @@ def load_model(self, *, model_config: ModelConfig,

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map))
return model
22 changes: 20 additions & 2 deletions vllm/plugins/__init__.py
@@ -1,8 +1,14 @@
import logging
from typing import Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Union

import vllm.envs as envs
from vllm.compilation.config import CompilationConfig

if TYPE_CHECKING:
from vllm.compilation.config import CompilationConfig
from vllm.config import VllmConfig
else:
CompilationConfig = None
VllmConfig = None

logger = logging.getLogger(__name__)

@@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]):

def get_compilation_config() -> Optional[CompilationConfig]:
return _compilation_config


_vllm_config: Optional[VllmConfig] = None


def set_vllm_config(config: Optional[VllmConfig]):
global _vllm_config
_vllm_config = config


def get_vllm_config() -> Optional[VllmConfig]:
return _vllm_config
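
The new module-level holder mirrors the existing compilation-config accessor: build_model() stashes the VllmConfig via set_vllm_config() just before constructing the model, so model code can read it during __init__ without new keyword arguments. A hedged sketch of that interim mechanism (ExampleModel and its scratch buffer are hypothetical, not part of this commit):

```python
import torch
from torch import nn

from vllm.plugins import get_vllm_config


class ExampleModel(nn.Module):

    def __init__(self, config, cache_config=None, quant_config=None):
        super().__init__()
        # build_model() calls set_vllm_config() right before constructing
        # the model, so the full config is visible here without threading
        # extra keyword arguments through every constructor.
        vllm_config = get_vllm_config()
        max_num_seqs = (vllm_config.scheduler_config.max_num_seqs
                        if vllm_config is not None else 256)
        # Hypothetical use: pre-allocate a per-sequence scratch buffer.
        self.register_buffer("seq_scratch", torch.zeros(max_num_seqs))
```

Per the TODO in build_model, this global is a stopgap until module initialization is unified to take the VllmConfig object directly.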
8 changes: 1 addition & 7 deletions vllm/v1/worker/gpu_model_runner.py
@@ -369,13 +369,7 @@ def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)

self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
8 changes: 1 addition & 7 deletions vllm/worker/cpu_model_runner.py
@@ -453,13 +453,7 @@ def model_is_mrope(self) -> bool:
return uses_mrope(self.model_config.hf_config)

def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)

def make_model_input_from_broadcasted_tensor_dict(
self,
8 changes: 1 addition & 7 deletions vllm/worker/model_runner.py
@@ -1051,13 +1051,7 @@ def __init__(
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)

self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",