[Hardware][Intel] OpenVINO vLLM backend
ilya-lavrenov committed Jun 10, 2024
1 parent 6b29d6f commit 178e52d
Showing 22 changed files with 1,358 additions and 20 deletions.
14 changes: 14 additions & 0 deletions .buildkite/run-openvino-test.sh
@@ -0,0 +1,14 @@
# This script builds the OpenVINO docker image and runs offline inference inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t openvino-test -f Dockerfile.openvino .

# Setup cleanup
remove_docker_container() { docker rm -f openvino-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py
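
As a point of reference, the offline inference example exercised by this test follows the standard vLLM offline API. A minimal sketch of that flow, with an illustrative model and prompts rather than the defaults shipped in examples/offline_inference.py:

from vllm import LLM, SamplingParams

# Illustrative prompts and model; examples/offline_inference.py ships its own defaults.
prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# With the OpenVINO build installed, device auto-detection picks the OpenVINO
# backend; no GPU is required.
llm = LLM(model="facebook/opt-125m")

for output in llm.generate(prompts, sampling_params):
    print(output.prompt, "->", output.outputs[0].text)
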
4 changes: 4 additions & 0 deletions .buildkite/test-template.j2
@@ -45,6 +45,10 @@ steps:
queue: intel
command: bash .buildkite/run-cpu-test.sh

- label: "OpenVINO Test"
depends_on: ~
command: bash .buildkite/run-openvino-test.sh

{% for step in steps %}
- label: "{{ step.label }}"
agents:
26 changes: 26 additions & 0 deletions Dockerfile.openvino
@@ -0,0 +1,26 @@
# The vLLM Dockerfile is used to construct a vLLM image that can be directly used
# to run the OpenAI-compatible server.

FROM ubuntu:22.04 AS dev

RUN apt-get update -y && \
apt-get install -y python3-pip git
WORKDIR /workspace

# copy requirements
COPY requirements-build.txt /workspace/vllm/
COPY requirements-common.txt /workspace/vllm/
COPY requirements-openvino.txt /workspace/vllm/

COPY vllm/ /workspace/vllm/vllm
COPY setup.py /workspace/vllm/

# install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
# build vLLM with OpenVINO backend
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/

COPY examples/ /workspace/vllm/examples
COPY benchmarks/ /workspace/vllm/benchmarks

CMD ["/bin/bash"]
7 changes: 4 additions & 3 deletions benchmarks/benchmark_latency.py
@@ -188,9 +188,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
default="auto",
choices=["auto", "cuda", "cpu", "openvino"],
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.')
parser.add_argument('--block-size',
type=int,
default=16,
7 changes: 4 additions & 3 deletions benchmarks/benchmark_throughput.py
@@ -345,9 +345,10 @@ def main(args: argparse.Namespace):
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
default="auto",
choices=["auto", "cuda", "cpu", "openvino"],
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
95 changes: 95 additions & 0 deletions docs/source/getting_started/openvino-installation.rst
@@ -0,0 +1,95 @@
.. _installation_openvino:

Installation with OpenVINO
==========================

vLLM powered by OpenVINO supports all LLM models from the :doc:`vLLM supported models list <../dev/models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support. The OpenVINO vLLM backend supports the following advanced vLLM features:

- Prefix caching (``--enable-prefix-caching``)
- Chunked prefill (``--enable-chunked-prefill``)

Table of contents:

#. :ref:`Requirements <openvino_backend_requirements>`
#. :ref:`Quick start using Dockerfile <openvino_backend_quick_start_dockerfile>`
#. :ref:`Install from source <install_openvino_backend_from_source>`
#. :ref:`Performance tips <openvino_backend_performance_tips>`
#. :ref:`Limitations <openvino_backend_limitations>`

.. _openvino_backend_requirements:

Requirements
------------

* OS: Linux
* Instruction set architecture (ISA) requirement: at least AVX2.

.. _openvino_backend_quick_start_dockerfile:

Quick start using Dockerfile
----------------------------

.. code-block:: console

    $ docker build -f Dockerfile.openvino -t vllm-openvino-env .
    $ docker run -it --rm vllm-openvino-env

.. _install_openvino_backend_from_source:

Install from source
-------------------

- First, install Python. For example, on Ubuntu 22.04, you can run:

  .. code-block:: console

      $ sudo apt-get update -y
      $ sudo apt-get install python3

- Second, install the prerequisites for the vLLM OpenVINO backend:

  .. code-block:: console

      $ pip install --upgrade pip
      $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu

- Finally, install vLLM with the OpenVINO backend:

  .. code-block:: console

      $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v .

.. _openvino_backend_performance_tips:

Performance tips
-----------------

The vLLM OpenVINO backend uses the following environment variables to control its behavior:

- ``VLLM_OPENVINO_KVCACHE_SPACE`` specifies the KV cache size (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger value lets vLLM run more requests in parallel. Set this parameter based on your hardware configuration and memory budget.

- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` controls the KV cache precision. By default, FP16 / BF16 is used, depending on the platform.

- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` enables U8 weight compression during the model loading stage. By default, compression is turned off.

To improve TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on our experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``).

The OpenVINO best-known configuration is:

.. code-block:: console

    $ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
        python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256

.. _openvino_backend_limitations:

Limitations
-----------------

- LoRA serving is not supported.

- Only LLM models are currently supported. LLaVA and encoder-decoder models are not yet enabled in the vLLM OpenVINO integration.

- Tensor and pipeline parallelism are not currently enabled in the vLLM OpenVINO integration.

- Speculative sampling is not tested within the vLLM OpenVINO integration.
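
If you prefer to drive these settings from Python rather than the shell, the sketch below shows one way to do it. It assumes the offline LLM API forwards engine arguments such as enable_chunked_prefill as it does for other backends, and the model name is illustrative only:

import os

# The backend reads these variables at engine start-up, so set them before
# constructing the LLM.
os.environ["VLLM_OPENVINO_KVCACHE_SPACE"] = "40"           # 40 GB KV cache
os.environ["VLLM_OPENVINO_CPU_KV_CACHE_PRECISION"] = "u8"  # U8 KV cache precision
os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",  # illustrative model
    enable_chunked_prefill=True,
    max_num_batched_tokens=256,
)
outputs = llm.generate(["OpenVINO makes CPU inference"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
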
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -64,6 +64,7 @@ Documentation
getting_started/installation
getting_started/amd-installation
getting_started/neuron-installation
getting_started/openvino-installation
getting_started/cpu-installation
getting_started/quickstart
getting_started/examples/examples_index
9 changes: 9 additions & 0 deletions requirements-openvino.txt
@@ -0,0 +1,9 @@
# Common dependencies
-r requirements-common.txt

# OpenVINO dependencies
torch >= 2.1.2
openvino ~= 2024.3.0.dev
optimum-intel[openvino] >= 1.17.2

triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
13 changes: 11 additions & 2 deletions setup.py
@@ -229,6 +229,10 @@ def _is_cpu() -> bool:
return VLLM_TARGET_DEVICE == "cpu"


def _is_openvino() -> bool:
return VLLM_TARGET_DEVICE == "openvino"


def _install_punica() -> bool:
return envs.VLLM_INSTALL_PUNICA_KERNELS

@@ -325,6 +329,8 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
elif _is_openvino():
version += "+openvino"
elif _is_cpu():
version += "+cpu"
else:
@@ -372,11 +378,14 @@ def _read_requirements(filename: str) -> List[str]:
requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt")
elif _is_openvino():
requirements = _read_requirements("requirements-openvino.txt")
elif _is_cpu():
requirements = _read_requirements("requirements-cpu.txt")
else:
raise ValueError(
"Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
"Unsupported platform, please use CUDA, ROCm, Neuron, "
"OpenVINO, or CPU.")
return requirements


@@ -385,7 +394,7 @@ def _read_requirements(filename: str) -> List[str]:
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

if not _is_neuron():
if not (_is_neuron() or _is_openvino()):
ext_modules.append(CMakeExtension(name="vllm._C"))

if _install_punica():
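
A quick way to confirm that a wheel was built for the OpenVINO target is to check the version suffix added above; a small sketch (the base version shown in the comment is illustrative):

import vllm

# An OpenVINO build carries a "+openvino" local version suffix,
# e.g. "0.4.3+openvino" (base version illustrative).
print(vllm.__version__)
assert vllm.__version__.endswith("+openvino")
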
75 changes: 75 additions & 0 deletions vllm/attention/backends/openvino.py
@@ -0,0 +1,75 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple

import openvino as ov
import torch

from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)


class OpenVINOAttentionBackend(AttentionBackend):

@staticmethod
def get_name() -> str:
return "openvino"

@staticmethod
def get_impl_cls():
# OpenVINO implements PagedAttention as part of the Optimum
# exported model
raise NotImplementedError

@staticmethod
def make_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
return OpenVINOAttentionMetadata(*args, **kwargs)

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, num_kv_heads, block_size, head_size)

@staticmethod
def swap_blocks(
src_kv_cache: ov.Tensor,
dst_kv_cache: ov.Tensor,
src_to_dst: torch.Tensor,
) -> None:
# OpenVINO currently supports only CPU, which does not require
# swap of KV cache blocks
raise NotImplementedError

@staticmethod
def copy_blocks(
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
src_to_dists: List[Tuple[int, int]],
) -> None:
for src, dst in src_to_dists:
for key_cache, value_cache in kv_caches:
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]


@dataclass
class OpenVINOAttentionMetadata(AttentionMetadata):
"""Metadata for OpenVINOAttentionBackend.
"""
past_lens: torch.Tensor
subsequence_begins: torch.Tensor
block_indices: torch.Tensor
block_indices_begins: torch.Tensor
max_context_len: torch.Tensor

@property
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
# OpenVINO uses its own metadata format
raise NotImplementedError

@property
def decode_metadata(self) -> Optional["AttentionMetadata"]:
# OpenVINO uses its own metadata format
raise NotImplementedError
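
To make the KV cache layout concrete, here is a small sketch of how the block-copy path above could be exercised. The per-layer layout of one (key, value) pair of ov.Tensor caches, each shaped (num_blocks, num_kv_heads, block_size, head_size), and all sizes used here are assumptions for illustration:

import numpy as np
import openvino as ov

from vllm.attention.backends.openvino import OpenVINOAttentionBackend

num_blocks, num_kv_heads, block_size, head_size = 8, 4, 16, 64

# Key and value caches are packed together in the reported shape.
print(OpenVINOAttentionBackend.get_kv_cache_shape(
    num_blocks, block_size, num_kv_heads, head_size))  # (2, 8, 4, 16, 64)

def make_cache() -> ov.Tensor:
    # Assumed per-tensor layout: (num_blocks, num_kv_heads, block_size, head_size)
    data = np.random.rand(num_blocks, num_kv_heads, block_size, head_size)
    return ov.Tensor(data.astype(np.float32))

# One (key_cache, value_cache) pair per attention layer.
kv_caches = [(make_cache(), make_cache())]

# Copy the contents of block 0 into block 3 in every layer's key and value cache.
OpenVINOAttentionBackend.copy_blocks(kv_caches, [(0, 3)])
assert np.array_equal(kv_caches[0][0].data[3], kv_caches[0][0].data[0])
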
10 changes: 9 additions & 1 deletion vllm/attention/selector.py
@@ -7,7 +7,7 @@
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip
from vllm.utils import is_cpu, is_hip, is_openvino

logger = init_logger(__name__)

@@ -17,6 +17,7 @@ class _Backend(enum.Enum):
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()


@@ -60,6 +61,10 @@ def get_attn_backend(
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.OPENVINO:
logger.info("Using OpenVINO Attention backend.")
from vllm.attention.backends.openvino import OpenVINOAttentionBackend
return OpenVINOAttentionBackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. "
@@ -100,6 +105,9 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA

if is_openvino():
return _Backend.OPENVINO

if is_hip():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
6 changes: 4 additions & 2 deletions vllm/config.py
@@ -11,7 +11,7 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_openvino

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -730,6 +730,8 @@ def __init__(self, device: str = "auto") -> None:
# Automated device type detection
if is_neuron():
self.device_type = "neuron"
elif is_openvino():
self.device_type = "openvino"
elif is_cpu():
self.device_type = "cpu"
else:
@@ -741,7 +743,7 @@ def __init__(self, device: str = "auto") -> None:
self.device_type = device

# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
if self.device_type in ["neuron", "openvino"]:
self.device = torch.device("cpu")
else:
# Set device with device type
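
As a small illustration of the auto-detection above, on a machine where the OpenVINO build of vLLM is installed, DeviceConfig resolves "auto" to the OpenVINO device type while keeping input processing on the CPU; a sketch:

import torch
from vllm.config import DeviceConfig

config = DeviceConfig(device="auto")
print(config.device_type)                    # "openvino" on an OpenVINO build
assert config.device == torch.device("cpu")  # inputs are processed on the CPU
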
11 changes: 6 additions & 5 deletions vllm/engine/arg_utils.py
@@ -494,11 +494,12 @@ def add_cli_args(
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "cpu"],
help='Device type for vLLM execution.')
parser.add_argument(
"--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "openvino", "cpu"],
help='Device type for vLLM execution.')

# Related to Vision-language models such as llava
parser = EngineArgs.add_cli_args_for_vlm(parser)
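
With the new choice wired into the CLI, the device can also be requested explicitly from Python. A minimal sketch, assuming the LLM constructor forwards device to EngineArgs as it does for the other device types (the model name is illustrative):

from vllm import LLM, SamplingParams

# "device" accepts the same values as the --device CLI flag, including "openvino".
llm = LLM(model="facebook/opt-125m", device="openvino")
out = llm.generate(["The weather today is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
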