Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vllm inference plugin #2967

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions plugins/flytekit-inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,66 @@ def model_serving(questions: list[str], gguf: FlyteFile) -> list[str]:

return responses
```

## vLLM

The vLLM plugin allows you to serve an LLM hosted on HuggingFace.

```python
import flytekit as fl
from openai import OpenAI

model_name = "google/gemma-2b-it"
hf_token_key = "vllm_hf_token"

vllm_args = {
"model": model_name,
"dtype": "half",
"max-model-len": 2000,
}

hf_secrets = HFSecret(
secrets_prefix="_FSEC_",
hf_token_key=hf_token_key
)

vllm_instance = VLLM(
hf_secret=hf_secrets,
arg_dict=vllm_args
)

image = fl.ImageSpec(
name="vllm_serve",
registry="...",
packages=["flytekitplugins-inference"],
)


@fl.task(
pod_template=vllm_instance.pod_template,
container_image=image,
secret_requests=[
fl.Secret(
key=hf_token_key, mount_requirement=fl.Secret.MountType.ENV_VAR # must be mounted as an env var
)
],
)
def model_serving() -> str:
client = OpenAI(
base_url=f"{vllm_instance.base_url}/v1", api_key="vllm" # api key required but ignored
)

completion = client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": "Compose a haiku about the power of AI.",
}
],
temperature=0.5,
top_p=1,
max_tokens=1024,
)
return completion.choices[0].message.content
```
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@

from .nim.serve import NIM, NIMSecrets
from .ollama.serve import Model, Ollama
from .vllm.serve import VLLM, HFSecret
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import Optional

from ..sidecar_template import ModelInferenceTemplate


@dataclass
class HFSecret:
"""
:param secrets_prefix: The secrets prefix that Flyte appends to all mounted secrets.
:param hf_token_group: The group name for the HuggingFace token.
:param hf_token_key: The key name for the HuggingFace token.
"""

secrets_prefix: str # _UNION_ or _FSEC_
hf_token_key: str
hf_token_group: Optional[str] = None


class VLLM(ModelInferenceTemplate):
def __init__(
self,
hf_secret: HFSecret,
arg_dict: Optional[dict] = None,
image: str = "vllm/vllm-openai",
health_endpoint: str = "/health",
port: int = 8000,
cpu: int = 2,
gpu: int = 1,
mem: str = "10Gi",
):
"""
Initialize NIM class for managing a Kubernetes pod template.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Initialize NIM class for managing a Kubernetes pod template.
Initialize VLLM class for managing a Kubernetes pod template.


:param hf_secret: Instance of HFSecret for managing hugging face secrets.
:param arg_dict: A dictionary of arguments for the VLLM model server (https://docs.vllm.ai/en/stable/models/engine_args.html).
:param image: The Docker image to be used for the model server container. Default is "ç".
samhita-alla marked this conversation as resolved.
Show resolved Hide resolved
:param health_endpoint: The health endpoint for the model server container. Default is "/health".
:param port: The port number for the model server container. Default is 8000.
:param cpu: The number of CPU cores requested for the model server container. Default is 2.
:param gpu: The number of GPU cores requested for the model server container. Default is 1.
:param mem: The amount of memory requested for the model server container. Default is "10Gi".
"""
if hf_secret.hf_token_key is None:
raise ValueError("HuggingFace token key must be provided.")
if hf_secret.secrets_prefix is None:
raise ValueError("Secrets prefix must be provided.")

self._hf_secret = hf_secret
self._arg_dict = arg_dict

super().__init__(
image=image,
health_endpoint=health_endpoint,
port=port,
cpu=cpu,
gpu=gpu,
mem=mem,
)

self.setup_vllm_pod_template()

def setup_vllm_pod_template(self):
from kubernetes.client.models import V1EnvVar

model_server_container = self.pod_template.pod_spec.init_containers[0]

if self._hf_secret.hf_token_group:
hf_key = f"$({self._hf_secret.secrets_prefix}{self._hf_secret.hf_token_group}_{self._hf_secret.hf_token_key})".upper()
else:
hf_key = f"$({self._hf_secret.secrets_prefix}{self._hf_secret.hf_token_key})".upper()

model_server_container.env = [
V1EnvVar(name="HUGGING_FACE_HUB_TOKEN", value=hf_key),
]
model_server_container.args = self.build_vllm_args()

def build_vllm_args(self) -> list:
args = []
if self._arg_dict:
for key, value in self._arg_dict.items():
args.append(f"--{key}")
if value is not None:
args.append(str(value))
return args
1 change: 1 addition & 0 deletions plugins/flytekit-inference/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
f"flytekitplugins.{PLUGIN_NAME}",
f"flytekitplugins.{PLUGIN_NAME}.nim",
f"flytekitplugins.{PLUGIN_NAME}.ollama",
f"flytekitplugins.{PLUGIN_NAME}.vllm",
],
install_requires=plugin_requires,
license="apache2",
Expand Down
60 changes: 60 additions & 0 deletions plugins/flytekit-inference/tests/test_vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from flytekitplugins.inference import VLLM, HFSecret


def test_vllm_init_valid_params():
vllm_args = {
"model": "google/gemma-2b-it",
"dtype": "half",
"max-model-len": 2000,
}

hf_secrets = HFSecret(
secrets_prefix="_UNION_",
hf_token_key="vllm_hf_token"
)

vllm_instance = VLLM(
hf_secret=hf_secrets,
arg_dict=vllm_args,
image='vllm/vllm-openai:my-tag',
cpu='10',
gpu='2',
mem='50Gi',
port=8080,
)

assert len(vllm_instance.pod_template.pod_spec.init_containers) == 1
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].image
== 'vllm/vllm-openai:my-tag'
)
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].resources.requests[
"memory"
]
== "50Gi"
)
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].ports[0].container_port
== 8080
)
assert vllm_instance.pod_template.pod_spec.init_containers[0].args == ['--model', 'google/gemma-2b-it', '--dtype', 'half', '--max-model-len', '2000']
assert vllm_instance.pod_template.pod_spec.init_containers[0].env[0].name == 'HUGGING_FACE_HUB_TOKEN'
assert vllm_instance.pod_template.pod_spec.init_containers[0].env[0].value == '$(_UNION_VLLM_HF_TOKEN)'



def test_vllm_default_params():
vllm_instance = VLLM(hf_secret=HFSecret(secrets_prefix="_FSEC_", hf_token_key="test_token"))

assert vllm_instance.base_url == "http://localhost:8000"
assert vllm_instance._image == 'vllm/vllm-openai'
assert vllm_instance._port == 8000
assert vllm_instance._cpu == 2
assert vllm_instance._gpu == 1
assert vllm_instance._health_endpoint == "/health"
assert vllm_instance._mem == "10Gi"
assert vllm_instance._arg_dict == None
assert vllm_instance._hf_secret.secrets_prefix == '_FSEC_'
assert vllm_instance._hf_secret.hf_token_key == 'test_token'
assert vllm_instance._hf_secret.hf_token_group == None
Loading