Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/N] pass the complete config from engine to executor #9933

Merged
merged 3 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def from_engine_args(

# Create the async LLM engine.
engine = cls(
**engine_config.to_dict(),
vllm_config=engine_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
Expand Down
49 changes: 20 additions & 29 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@
from typing_extensions import TypeIs, TypeVar

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
Expand Down Expand Up @@ -222,24 +219,29 @@ def validate_outputs(

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
vllm_config: EngineConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None:

model_config = self.model_config = vllm_config.model_config
cache_config = self.cache_config = vllm_config.cache_config
lora_config = self.lora_config = vllm_config.lora_config
parallel_config = self.parallel_config = vllm_config.parallel_config
scheduler_config = self.scheduler_config = vllm_config.scheduler_config
device_config = self.device_config = vllm_config.device_config
speculative_config = self.speculative_config = vllm_config.speculative_config # noqa
load_config = self.load_config = vllm_config.load_config
decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
)
prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)

logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
Expand Down Expand Up @@ -340,18 +342,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
self.input_processor = input_registry.create_input_processor(
model_config)

self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
self.model_executor = executor_class(vllm_config=vllm_config, )

if self.model_config.task != "embedding":
self._initialize_kv_caches()
Expand Down Expand Up @@ -582,7 +573,7 @@ def from_engine_args(
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
**engine_config.to_dict(),
vllm_config=engine_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
Expand Down
7 changes: 1 addition & 6 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import zmq

from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
Expand All @@ -30,9 +28,6 @@
else:
from vllm.engine.llm_engine import LLMEngine

CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]

logger = init_logger(__name__)

POLLING_TIMEOUT_MS = 10000
Expand Down Expand Up @@ -130,7 +125,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,

return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
**engine_config.to_dict(),
vllm_config=engine_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
Expand Down
37 changes: 13 additions & 24 deletions vllm/executor/executor_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import EngineConfig
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
Expand All @@ -23,27 +20,19 @@ class ExecutorBase(ABC):

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
vllm_config: EngineConfig,
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self._init_executor()

@abstractmethod
Expand Down
61 changes: 20 additions & 41 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
Union)

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
Expand Down Expand Up @@ -35,24 +32,29 @@ class LLMEngine:

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
vllm_config: EngineConfig,
executor_class: Type[GPUExecutor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None:

model_config = self.model_config = vllm_config.model_config
cache_config = self.cache_config = vllm_config.cache_config
lora_config = self.lora_config = vllm_config.lora_config
parallel_config = self.parallel_config = vllm_config.parallel_config
scheduler_config = self.scheduler_config = vllm_config.scheduler_config
device_config = self.device_config = vllm_config.device_config
speculative_config = self.speculative_config = vllm_config.speculative_config # noqa
load_config = self.load_config = vllm_config.load_config
decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
)

# Override the configs for V1.
# FIXME
if usage_context == UsageContext.LLM_CLASS:
Expand Down Expand Up @@ -112,18 +114,6 @@ def __init__(
model_config.mm_processor_kwargs,
)

self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats

assert not self.model_config.skip_tokenizer_init
Expand Down Expand Up @@ -154,18 +144,7 @@ def __init__(
# Request id -> RequestOutput
self.request_outputs: Dict[str, RequestOutput] = {}

self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
self.model_executor = executor_class(vllm_config=vllm_config)
assert self.model_config.task != "embedding"
self._initialize_kv_caches()

Expand Down Expand Up @@ -203,7 +182,7 @@ def from_engine_args(
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
**engine_config.to_dict(),
vllm_config=engine_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
Expand Down