Commit d076d08: openvino
Signed-off-by: youkaichao <[email protected]>
youkaichao committed Nov 1, 2024
1 parent 2a8678f commit d076d08
Showing 2 changed files with 19 additions and 48 deletions.
33 changes: 13 additions & 20 deletions vllm/worker/openvino_model_runner.py
@@ -6,16 +6,15 @@
 
 from vllm.attention import get_attn_backend
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         SchedulerConfig)
+from vllm.config import EngineConfig
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.openvino import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
 from vllm.sequence import SequenceGroupMetadata
+from vllm.worker.model_runner_base import ModelRunnerBase
 
 logger = init_logger(__name__)

@@ -38,33 +37,21 @@ def empty(cls, device):
                           multi_modal_kwargs={})
 
 
-class OpenVINOModelRunner:
+class OpenVINOModelRunner(ModelRunnerBase):
 
     def __init__(
         self,
         ov_core: ov.Core,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        multimodal_config: Optional[MultiModalConfig],
+        vllm_config: EngineConfig,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
         *args,
         **kwargs,
     ):
         self.ov_core = ov_core
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.multimodal_config = multimodal_config
-        self.load_config = load_config
+        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
+        cache_config = self.cache_config
+        model_config = self.model_config
         self.is_driver_worker = is_driver_worker
 
         self.device = self.device_config.device
@@ -350,3 +337,9 @@ def execute_model(
             sampling_metadata=sampling_metadata,
         )
         return output
+
+    def prepare_model_input(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
+        raise NotImplementedError
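
The model-runner half of the commit is pure plumbing: eight separately passed config objects collapse into one vllm_config: EngineConfig, and ModelRunnerBase.__init__ is expected to fan that bundle out into the attributes the rest of the file already reads (self.model_config, self.cache_config, and so on; note the two local aliases kept for downstream code). A minimal sketch of that assumed unpacking; the field list below is inferred from the attributes used above, not copied from vLLM:

# Hypothetical sketch of the base-class constructor this diff leans on.
# Field names are assumptions inferred from the attributes the runner
# reads (self.model_config, self.cache_config, self.device_config, ...).
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class EngineConfig:
    model_config: Any
    cache_config: Any
    parallel_config: Any
    scheduler_config: Any
    device_config: Any
    load_config: Any
    lora_config: Optional[Any] = None


class ModelRunnerBase:
    def __init__(self, vllm_config: EngineConfig) -> None:
        # Fan the bundle out so subclasses keep their old attribute names.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.load_config = vllm_config.load_config
        self.lora_config = vllm_config.lora_config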
34 changes: 6 additions & 28 deletions vllm/worker/openvino_worker.py
@@ -7,9 +7,8 @@
 
 import vllm.envs as envs
 from vllm.attention import get_attn_backend
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         SchedulerConfig)
+from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, ModelConfig,
+                         ParallelConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -22,7 +21,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
 
 logger = init_logger(__name__)

@@ -212,33 +211,19 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
     def __init__(
         self,
         ov_core: ov.Core,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
+        vllm_config: EngineConfig,
         local_rank: int,
         rank: int,
         distributed_init_method: str,
-        lora_config: Optional[LoRAConfig] = None,
-        multimodal_config: Optional[MultiModalConfig] = None,
         kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
         is_driver_worker: bool = False,
     ) -> None:
         self.ov_core = ov_core
-        self.model_config = model_config
-        self.parallel_config = parallel_config
+        WorkerBase.__init__(self, vllm_config)
         self.parallel_config.rank = rank
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.cache_config = cache_config
-        self.load_config = load_config
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
-        self.lora_config = lora_config
-        self.multimodal_config = multimodal_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -250,14 +235,7 @@ def __init__(
             init_cached_hf_modules()
         self.model_runner = OpenVINOModelRunner(
            self.ov_core,
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
-            cache_config,
-            load_config=self.load_config,
-            lora_config=self.lora_config,
-            multimodal_config=self.multimodal_config,
+            vllm_config=self.vllm_config,
             kv_cache_dtype=kv_cache_dtype,
             is_driver_worker=is_driver_worker,
         )
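
The worker half mirrors the same move, and the call site that builds the model runner shrinks from eight config arguments to one. For orientation, a hedged usage sketch of the post-commit constructor; engine_config stands in for however the engine assembles its EngineConfig (not an API shown in this commit), and the rank and address values are placeholders:

# Hedged caller sketch: constructing the worker after this commit.
# `engine_config` is a stand-in for an EngineConfig built elsewhere
# (e.g. by the engine's argument parsing); it is not part of this diff.
import openvino as ov

from vllm.worker.openvino_worker import OpenVINOWorker


def build_worker(engine_config) -> OpenVINOWorker:
    return OpenVINOWorker(
        ov_core=ov.Core(),
        vllm_config=engine_config,  # one bundle replaces the old config list
        local_rank=0,
        rank=0,
        distributed_init_method="tcp://127.0.0.1:12345",
        is_driver_worker=True,
    )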
