[platforms] absorb worker cls difference into platforms folder (vllm-project#10555)

Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
youkaichao and njhill authored Nov 22, 2024
1 parent 446c780 commit a111d01
Showing 21 changed files with 273 additions and 283 deletions.
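
At a high level, the PR replaces per-executor hard-coded worker module/class names with a single ParallelConfig.worker_cls string that defaults to "auto" and is filled in by the current platform. A rough sketch of that idea follows; the helper name below is illustrative, not the exact hook added under vllm/platforms/ (those files are not among the hunks shown here).

# Hypothetical sketch: the hook name is an assumption; only the "auto" contract
# and the CPU worker path are taken from the diff below.
from vllm.config import VllmConfig


def update_worker_cls_for_cpu(vllm_config: VllmConfig) -> None:
    """Example of what a CPU-platform hook might do."""
    parallel_config = vllm_config.parallel_config
    if parallel_config.worker_cls == "auto":
        # Mirrors the strings deleted from vllm/executor/cpu_executor.py below.
        parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"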
238 changes: 117 additions & 121 deletions vllm/config.py
@@ -926,56 +926,56 @@ def _verify_load_format(self) -> None:
                 f"{rocm_supported_load_format}")
 
 
+@dataclass
 class ParallelConfig:
-    """Configuration for the distributed execution.
+    """Configuration for the distributed execution."""
 
-    Args:
-        pipeline_parallel_size: Number of pipeline parallel groups.
-        tensor_parallel_size: Number of tensor parallel groups.
-        worker_use_ray: Deprecated, use distributed_executor_backend instead.
-        max_parallel_loading_workers: Maximum number of multiple batches
-            when load model sequentially. To avoid RAM OOM when using tensor
-            parallel and large models.
-        disable_custom_all_reduce: Disable the custom all-reduce kernel and
-            fall back to NCCL.
-        tokenizer_pool_config: Config for the tokenizer pool.
-            If None, will use synchronous tokenization.
-        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
-            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
-        placement_group: ray distributed model workers placement group.
-        distributed_executor_backend: Backend to use for distributed model
-            workers, either "ray" or "mp" (multiprocessing). If the product
-            of pipeline_parallel_size and tensor_parallel_size is less than
-            or equal to the number of GPUs available, "mp" will be used to
-            keep processing on a single host. Otherwise, this will default
-            to "ray" if Ray is installed and fail otherwise. Note that tpu
-            and hpu only support Ray for distributed inference.
-    """
+    pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
+    tensor_parallel_size: int = 1  # Number of tensor parallel groups.
 
-    def __init__(
-        self,
-        pipeline_parallel_size: int,
-        tensor_parallel_size: int,
-        worker_use_ray: Optional[bool] = None,
-        max_parallel_loading_workers: Optional[int] = None,
-        disable_custom_all_reduce: bool = False,
-        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
-        ray_workers_use_nsight: bool = False,
-        placement_group: Optional["PlacementGroup"] = None,
-        distributed_executor_backend: Optional[Union[
-            str, Type["ExecutorBase"]]] = None,
-    ) -> None:
-        self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
-        self.distributed_executor_backend = distributed_executor_backend
-        self.max_parallel_loading_workers = max_parallel_loading_workers
-        self.disable_custom_all_reduce = disable_custom_all_reduce
-        self.tokenizer_pool_config = tokenizer_pool_config
-        self.ray_workers_use_nsight = ray_workers_use_nsight
-        self.placement_group = placement_group
-        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
-
-        if worker_use_ray:
+    # Deprecated, use distributed_executor_backend instead.
+    worker_use_ray: Optional[bool] = None
+
+    # Maximum number of multiple batches
+    # when load model sequentially. To avoid RAM OOM when using tensor
+    # parallel and large models.
+    max_parallel_loading_workers: Optional[int] = None
+
+    # Disable the custom all-reduce kernel and fall back to NCCL.
+    disable_custom_all_reduce: bool = False
+
+    # Config for the tokenizer pool. If None, will use synchronous tokenization.
+    tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
+
+    # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
+    ray_workers_use_nsight: bool = False
+
+    # ray distributed model workers placement group.
+    placement_group: Optional["PlacementGroup"] = None
+
+    # Backend to use for distributed model
+    # workers, either "ray" or "mp" (multiprocessing). If the product
+    # of pipeline_parallel_size and tensor_parallel_size is less than
+    # or equal to the number of GPUs available, "mp" will be used to
+    # keep processing on a single host. Otherwise, this will default
+    # to "ray" if Ray is installed and fail otherwise. Note that tpu
+    # and hpu only support Ray for distributed inference.
+    distributed_executor_backend: Optional[Union[str,
+                                                 Type["ExecutorBase"]]] = None
+
+    # the full name of the worker class to use. If "auto", the worker class
+    # will be determined based on the platform.
+    worker_cls: str = "auto"
+
+    world_size: int = field(init=False)
+
+    rank: int = 0
+
+    def __post_init__(self) -> None:
+        self.world_size = self.pipeline_parallel_size * \
+            self.tensor_parallel_size
+
+        if self.worker_use_ray:
             if self.distributed_executor_backend is None:
                 self.distributed_executor_backend = "ray"
             elif not self.use_ray:
@@ -1026,7 +1026,6 @@ def __init__(
                     backend)
 
         self._verify_args()
-        self.rank: int = 0
 
     @property
     def use_ray(self) -> bool:
@@ -1059,100 +1058,97 @@ def _verify_args(self) -> None:
                 "run with Ray.")
 
 
+@dataclass
 class SchedulerConfig:
-    """Scheduler configuration.
+    """Scheduler configuration."""
 
-    Args:
-        task: The task to use the model for.
-        max_num_batched_tokens: Maximum number of tokens to be processed in
-            a single iteration.
-        max_num_seqs: Maximum number of sequences to be processed in a single
-            iteration.
-        max_model_len: Maximum length of a sequence (including prompt
-            and generated text).
-        num_lookahead_slots: The number of slots to allocate per sequence per
-            step, beyond the known token ids. This is used in speculative
-            decoding to store KV activations of tokens which may or may not be
-            accepted.
-        delay_factor: Apply a delay (of delay factor multiplied by previous
-            prompt latency) before scheduling next prompt.
-        enable_chunked_prefill: If True, prefill requests can be chunked based
-            on the remaining max_num_batched_tokens.
-        preemption_mode: Whether to perform preemption by swapping or
-            recomputation. If not specified, we determine the mode as follows:
-            We use recomputation by default since it incurs lower overhead than
-            swapping. However, when the sequence group has multiple sequences
-            (e.g., beam search), recomputation is not currently supported. In
-            such a case, we use swapping instead.
-        send_delta_data: Private API. If used, scheduler sends delta data to
-            workers instead of an entire data. It should be enabled only
-            when SPMD worker architecture is enabled. I.e.,
-            VLLM_USE_RAY_SPMD_WORKER=1
-        policy: The scheduling policy to use. "fcfs" (default) or "priority".
-    """
+    task: str = "generate"  # The task to use the model for.
+
+    # Maximum number of tokens to be processed in a single iteration.
+    max_num_batched_tokens: int = field(default=None)  # type: ignore
+
+    # Maximum number of sequences to be processed in a single iteration.
+    max_num_seqs: int = 128
+
+    # Maximum length of a sequence (including prompt and generated text).
+    max_model_len: int = 8192
+
+    # The number of slots to allocate per sequence per
+    # step, beyond the known token ids. This is used in speculative
+    # decoding to store KV activations of tokens which may or may not be
+    # accepted.
+    num_lookahead_slots: int = 0
+
+    # Apply a delay (of delay factor multiplied by previous
+    # prompt latency) before scheduling next prompt.
+    delay_factor: float = 0.0
+
+    # If True, prefill requests can be chunked based
+    # on the remaining max_num_batched_tokens.
+    enable_chunked_prefill: bool = False
+
+    is_multimodal_model: bool = False
+
+    # Whether to perform preemption by swapping or
+    # recomputation. If not specified, we determine the mode as follows:
+    # We use recomputation by default since it incurs lower overhead than
+    # swapping. However, when the sequence group has multiple sequences
+    # (e.g., beam search), recomputation is not currently supported. In
+    # such a case, we use swapping instead.
+    preemption_mode: Optional[str] = None
+
+    num_scheduler_steps: int = 1
+
+    multi_step_stream_outputs: bool = False
+
+    # Private API. If used, scheduler sends delta data to
+    # workers instead of an entire data. It should be enabled only
+    # when SPMD worker architecture is enabled. I.e.,
+    # VLLM_USE_RAY_SPMD_WORKER=1
+    send_delta_data: bool = False
+
+    # The scheduling policy to use. "fcfs" (default) or "priority".
+    policy: str = "fcfs"
+
+    chunked_prefill_enabled: bool = field(init=False)
 
-    def __init__(self,
-                 task: _Task,
-                 max_num_batched_tokens: Optional[int],
-                 max_num_seqs: int,
-                 max_model_len: int,
-                 num_lookahead_slots: int = 0,
-                 delay_factor: float = 0.0,
-                 enable_chunked_prefill: bool = False,
-                 is_multimodal_model: bool = False,
-                 preemption_mode: Optional[str] = None,
-                 num_scheduler_steps: int = 1,
-                 multi_step_stream_outputs: bool = False,
-                 send_delta_data: bool = False,
-                 policy: str = "fcfs") -> None:
-        if max_num_batched_tokens is None:
-            if enable_chunked_prefill:
-                if num_scheduler_steps > 1:
+    def __post_init__(self) -> None:
+        if self.max_num_batched_tokens is None:
+            if self.enable_chunked_prefill:
+                if self.num_scheduler_steps > 1:
                     # Multi-step Chunked-Prefill doesn't allow prompt-chunking
                     # for now. Have max_num_batched_tokens set to max_model_len
                     # so we don't reject sequences on account of a short
                     # max_num_batched_tokens.
-                    max_num_batched_tokens = max(max_model_len, 2048)
+                    self.max_num_batched_tokens = max(self.max_model_len, 2048)
                 else:
                     # It is the values that have the best balance between ITL
                     # and TTFT on A100. Note it is not optimized for throughput.
-                    max_num_batched_tokens = 512
+                    self.max_num_batched_tokens = 512
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.
-                max_num_batched_tokens = max(max_model_len, 2048)
+                self.max_num_batched_tokens = max(self.max_model_len, 2048)
 
-        if task == "embedding":
+        if self.task == "embedding":
             # For embedding, choose specific value for higher throughput
-            max_num_batched_tokens = max(
-                max_num_batched_tokens,
+            self.max_num_batched_tokens = max(
+                self.max_num_batched_tokens,
                 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
             )
-        if is_multimodal_model:
+        if self.is_multimodal_model:
             # The value needs to be at least the number of multimodal tokens
-            max_num_batched_tokens = max(
-                max_num_batched_tokens,
+            self.max_num_batched_tokens = max(
+                self.max_num_batched_tokens,
                 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
             )
 
-        self.max_num_batched_tokens = max_num_batched_tokens
-
-        if enable_chunked_prefill:
+        if self.enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                 self.max_num_batched_tokens)
 
-        self.task: Final = task
-        self.max_num_seqs = max_num_seqs
-        self.max_model_len = max_model_len
-        self.num_lookahead_slots = num_lookahead_slots
-        self.delay_factor = delay_factor
-        self.chunked_prefill_enabled = enable_chunked_prefill
-        self.preemption_mode = preemption_mode
-        self.num_scheduler_steps = num_scheduler_steps
-        self.multi_step_stream_outputs = multi_step_stream_outputs
-        self.send_delta_data = send_delta_data
-        self.policy = policy
+        self.chunked_prefill_enabled = self.enable_chunked_prefill
         self._verify_args()
 
     def _verify_args(self) -> None:
@@ -2293,10 +2289,10 @@ class VllmConfig:
 
     model_config: ModelConfig = field(default=None, init=True)  # type: ignore
    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
-    parallel_config: ParallelConfig = field(default=None,
-                                            init=True)  # type: ignore
-    scheduler_config: SchedulerConfig = field(default=None,
-                                              init=True)  # type: ignore
+    parallel_config: ParallelConfig = field(default_factory=ParallelConfig,
+                                            init=True)
+    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig,
+                                              init=True)
     device_config: DeviceConfig = field(default=None,
                                         init=True)  # type: ignore
     load_config: LoadConfig = field(default=None, init=True)  # type: ignore
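
Because ParallelConfig and SchedulerConfig are now dataclasses with defaults, they can be constructed with keyword overrides only, and derived fields are computed in __post_init__. A minimal sketch, assuming only the defaults and branches visible in the hunks above:

# Minimal sketch based on the fields and __post_init__ logic shown above;
# not taken from the PR's own tests.
from vllm.config import ParallelConfig, SchedulerConfig

parallel_config = ParallelConfig()                # all defaults
assert parallel_config.world_size == 1            # derived in __post_init__
assert parallel_config.worker_cls == "auto"       # resolved later by the platform

scheduler_config = SchedulerConfig(enable_chunked_prefill=True)
# With chunked prefill on and num_scheduler_steps == 1, the token budget
# falls back to 512, per the branch above.
assert scheduler_config.max_num_batched_tokens == 512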
11 changes: 10 additions & 1 deletion vllm/engine/arg_utils.py
@@ -191,6 +191,7 @@ class EngineArgs:
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
+    worker_cls: str = "auto"
 
     def __post_init__(self):
         if not self.tokenizer:
@@ -887,6 +888,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             'compilers, using -O without space is also '
                             'supported. -O3 is equivalent to -O 3.')
 
+        parser.add_argument(
+            '--worker-cls',
+            type=str,
+            default="auto",
+            help='The worker class to use for distributed execution.')
+
         return parser
 
     @classmethod
@@ -999,7 +1006,9 @@ def create_engine_config(self) -> VllmConfig:
                 self.tokenizer_pool_extra_config,
             ),
             ray_workers_use_nsight=self.ray_workers_use_nsight,
-            distributed_executor_backend=self.distributed_executor_backend)
+            distributed_executor_backend=self.distributed_executor_backend,
+            worker_cls=self.worker_cls,
+        )
 
         max_model_len = model_config.max_model_len
         use_long_context = max_model_len > 32768
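
The new --worker-cls flag simply threads through EngineArgs into ParallelConfig.worker_cls. A hedged sketch of the CLI round trip, assuming FlexibleArgumentParser is importable from vllm.utils in this version; the dotted worker path below is only an example value:

# Sketch of the CLI round trip; the worker path is illustrative.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--worker-cls", "vllm.worker.worker.Worker",
])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.worker_cls == "vllm.worker.worker.Worker"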
7 changes: 1 addition & 6 deletions vllm/executor/cpu_executor.py
@@ -115,13 +115,8 @@ def _create_worker(
         local_rank: int = 0,
         rank: int = 0,
     ):
-        worker_module_name = "vllm.worker.cpu_worker"
-        worker_class_name = "CPUWorker"
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
+        wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
 
         assert self.distributed_init_method is not None
 
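
With the executor no longer passing worker_module_name/worker_class_name, WorkerWrapperBase is expected to derive the worker class from vllm_config.parallel_config.worker_cls. A generic sketch of that kind of string-to-class resolution (the wrapper's actual implementation is outside this excerpt):

# Generic sketch: resolve a dotted "module.ClassName" string such as
# "vllm.worker.cpu_worker.CPUWorker" into the class object.
from importlib import import_module


def resolve_worker_cls(worker_cls: str) -> type:
    module_name, _, class_name = worker_cls.rpartition(".")
    return getattr(import_module(module_name), class_name)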
(The remaining 18 changed files in this commit are not shown in this excerpt.)
