
Commit

Merge pull request #4 from luo-cheng2021/luocheng/openvino-model-executor

Adapt OpenVINO CPU plugin implementation
ilya-lavrenov authored Mar 21, 2024
2 parents 2922b06 + f2dd24e commit 658407a
Showing 3 changed files with 38 additions and 9 deletions.
20 changes: 18 additions & 2 deletions vllm/executor/openvino_executor.py
@@ -18,6 +18,7 @@
from vllm.sampling_params import SamplingParams

import openvino as ov
import openvino.properties.hint as hints

logger = init_logger(__name__)

@@ -59,6 +60,11 @@ def __init__(
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)

if device_config.device.type == "cpu":
if cache_config.block_size != 1:
cache_config.num_cpu_blocks *= cache_config.block_size
cache_config.block_size = 1
print(f"Warning: CPU only support block_size = 1, it's forced to 1, num_cpu_blocks is set to {cache_config.num_cpu_blocks}.")
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
self.num_cpu_blocks = cache_config.num_cpu_blocks
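
Note: the CPU branch added above folds the old block size into num_cpu_blocks before forcing block_size to 1, so the total number of KV-cache slots stays the same. A minimal standalone sketch of that arithmetic (the SimpleNamespace below is a hypothetical stand-in for vLLM's CacheConfig, not the real class):

from types import SimpleNamespace

# Hypothetical stand-in for CacheConfig; only the fields touched above are modeled.
cache_config = SimpleNamespace(block_size=16, num_cpu_blocks=128)

# 128 blocks x 16 slots -> 2048 blocks x 1 slot: same total capacity.
if cache_config.block_size != 1:
    cache_config.num_cpu_blocks *= cache_config.block_size
    cache_config.block_size = 1

assert cache_config.block_size == 1
assert cache_config.num_cpu_blocks == 2048
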
@@ -146,7 +152,15 @@ def get_cache_dtype(
# probably, we need to force OpenVINO kv cache data types per device and assert
# if user specified a different value
if cache_dtype == "auto":
cache_dtype = model_config.dtype
if device_config.device.type == "cpu":
core = ov.Core()
inference_precision = core.get_property("CPU", hints.inference_precision)
if inference_precision == ov.Type.bf16:
cache_dtype = torch.bfloat16
else:
cache_dtype = torch.float16
else:
cache_dtype = model_config.dtype
else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
ov_cache_dtype = TORCH_DTYPE_TO_OPENVINO_DTYPE[cache_dtype]
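
Note: with cache_dtype == "auto" on CPU, the KV-cache dtype is now derived from the CPU plugin's inference-precision hint instead of the model dtype. A standalone sketch of that query, assuming openvino and torch are installed (the plugin reports bf16 on hosts with native bf16 support; otherwise the logic above falls back to fp16):

import openvino as ov
import openvino.properties.hint as hints
import torch

core = ov.Core()
# Ask the CPU plugin which precision it will actually execute in.
inference_precision = core.get_property("CPU", hints.inference_precision)

# Mirror the selection above: keep bf16 where the plugin runs bf16, else use fp16.
kv_cache_dtype = torch.bfloat16 if inference_precision == ov.Type.bf16 else torch.float16
print(inference_precision, "->", kv_cache_dtype)
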
@@ -469,7 +483,9 @@ def _init_worker(self):
self.scheduler_config,
self.device_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
kv_cache_dtype=OpenVINOCacheEngine.get_cache_dtype(self.cache_config.cache_dtype,
self.model_config,
self.device_config)
)
self.driver_worker.init_model()
self.driver_worker.load_model()
26 changes: 19 additions & 7 deletions vllm/model_executor/openvino_model_loader.py
@@ -12,6 +12,7 @@
from vllm.utils import is_openvino_optimum_intel

import openvino as ov
from openvino import Type


def _flattenize_inputs(inputs):
@@ -56,7 +57,8 @@ def ov_wrapper(self, *args, **kwargs) -> torch.Tensor:

def patch_stateful_model(
model: ov.Model,
factory):
factory,
kv_cache_dtype: Type):
print('TRANSFORMING OPTIMUM-INTEL MODEL TO vLLM COMPATIBLE FORM')
from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher, AnyInput, Or
from openvino.runtime import opset13
@@ -128,8 +130,8 @@ def callback(m: Matcher) -> bool:
real_v = mapping[v_current]
hidden_shape = real_q.get_partial_shape()
hidden_dim = hidden_shape[hidden_shape.rank.get_length() - 1].get_length()  # TODO: What if it is dynamic? Need to insert a ShapeOf sub-graph instead
k_parameter = opset13.parameter(shape=[-1, -1, -1, -1, -1], dtype=np.float32)
v_parameter = opset13.parameter(shape=[-1, -1, -1, -1], dtype=np.float32)
k_parameter = opset13.parameter(shape=[-1, -1, -1, -1, -1], dtype=kv_cache_dtype)
v_parameter = opset13.parameter(shape=[-1, -1, -1, -1], dtype=kv_cache_dtype)
kv_parameters.append(k_parameter)
kv_parameters.append(v_parameter)
# TODO: Rank 4 is used in the following code, but it is not guaranteed for all models; adapt to other ranks.
@@ -274,7 +276,8 @@ def has_parameter(model, name):

def _patch_model_with_openvino(
pt_model: torch.nn.Module,
model_config: ModelConfig):
model_config: ModelConfig,
kv_cache_dtype: Type):
print(' ============= PATCHING MODEL =============')
from vllm.model_executor.layers.attention.attention import Attention
from openvino.frontend.pytorch import ModuleExtension
@@ -294,7 +297,15 @@ def _patch_model_with_openvino(

# Prepare example inputs

kv_cache_dtype = torch.float32
torch_dtype_mapping = {
Type.boolean: torch.bool,
Type.f32: torch.float32,
Type.f16: torch.float16,
Type.bf16: torch.bfloat16,
Type.i32: torch.int32,
Type.i64: torch.int64
}
kv_cache_dtype = torch_dtype_mapping[kv_cache_dtype]
num_heads = pt_model.config.num_attention_heads
num_kv_heads = num_heads
head_size = pt_model.config.hidden_size // num_kv_heads
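
Note: _patch_model_with_openvino now receives the KV-cache dtype as an openvino Type and translates it to the matching torch dtype before building example inputs. A small sketch of that translation (the mapping name and the example tensor shape are illustrative only, not the layout the loader actually traces with):

import torch
from openvino import Type

# Same mapping as in the hunk above, under an illustrative name.
ov_to_torch = {
    Type.boolean: torch.bool,
    Type.f32: torch.float32,
    Type.f16: torch.float16,
    Type.bf16: torch.bfloat16,
    Type.i32: torch.int32,
    Type.i64: torch.int64,
}

kv_cache_dtype = ov_to_torch[Type.bf16]  # -> torch.bfloat16
# Illustrative example-input tensor in the selected KV-cache dtype.
example_kv = torch.zeros(1, 8, 1, 64, dtype=kv_cache_dtype)
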
@@ -423,6 +434,7 @@ def ov_sample(

def get_model(model_config: ModelConfig,
device_config: DeviceConfig,
kv_cache_dtype: Type,
**kwargs) -> torch.nn.Module:
lora_config = kwargs.get("lora_config", None)
if lora_config:
@@ -443,7 +455,7 @@ def get_model(model_config: ModelConfig,
# Keep the factory to destroy it at the right moment, when all other objects referencing custom nodes are destroyed
pt_model.ov_node_factory = NodeFactory()
pt_model.ov_node_factory.add_extension('libuser_ov_extensions.so')
patch_stateful_model(pt_model.model, pt_model.ov_node_factory)
patch_stateful_model(pt_model.model, pt_model.ov_node_factory, kv_cache_dtype)
core = ov.Core()
ov_compiled = core.compile_model(pt_model.model, "CPU")
pt_model._ov_request = ov_compiled.create_infer_request()
@@ -457,6 +469,6 @@ def get_model(model_config: ModelConfig,
else:
from vllm.model_executor.model_loader import get_model
pt_model = get_model(model_config, device_config, **kwargs)
_patch_model_with_openvino(pt_model, model_config)
_patch_model_with_openvino(pt_model, model_config, kv_cache_dtype)

return pt_model
1 change: 1 addition & 0 deletions vllm/worker/model_runner.py
@@ -88,6 +88,7 @@ def load_model(self) -> None:
with measure_cuda_memory(self.device) as m:
self.model = get_model(self.model_config,
self.device_config,
kv_cache_dtype=self.kv_cache_dtype,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)

