Skip to content

Commit

Permalink
feature: Support choosing Kubernetes QoS class through the decorator (N…
Browse files Browse the repository at this point in the history
  • Loading branch information
saikonen authored Dec 5, 2024
1 parent 9e532f2 commit 0469eeb
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 55 deletions.
2 changes: 2 additions & 0 deletions metaflow/metaflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@
KUBERNETES_CPU = from_conf("KUBERNETES_CPU", None)
KUBERNETES_MEMORY = from_conf("KUBERNETES_MEMORY", None)
KUBERNETES_DISK = from_conf("KUBERNETES_DISK", None)
# Default kubernetes QoS class
KUBERNETES_QOS = from_conf("KUBERNETES_QOS", "burstable")

ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "")
ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "")
Expand Down
35 changes: 18 additions & 17 deletions metaflow/plugins/airflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
# TODO: Move chevron to _vendor
from metaflow.plugins.cards.card_modules import chevron
from metaflow.plugins.kubernetes.kubernetes import Kubernetes
from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits
from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
from metaflow.util import compress_list, dict_to_cli_options, get_username

Expand Down Expand Up @@ -428,25 +429,25 @@ def _to_job(self, node):
if k8s_deco.attributes["namespace"] is not None
else "default"
)

qos_requests, qos_limits = qos_requests_and_limits(
k8s_deco.attributes["qos"],
k8s_deco.attributes["cpu"],
k8s_deco.attributes["memory"],
k8s_deco.attributes["disk"],
)
resources = dict(
requests={
"cpu": k8s_deco.attributes["cpu"],
"memory": "%sM" % str(k8s_deco.attributes["memory"]),
"ephemeral-storage": str(k8s_deco.attributes["disk"]),
}
requests=qos_requests,
limits={
**qos_limits,
**{
"%s.com/gpu".lower()
% k8s_deco.attributes["gpu_vendor"]: str(k8s_deco.attributes["gpu"])
for k in [0]
# Don't set GPU limits if gpu isn't specified.
if k8s_deco.attributes["gpu"] is not None
},
},
)
if k8s_deco.attributes["gpu"] is not None:
resources.update(
dict(
limits={
"%s.com/gpu".lower()
% k8s_deco.attributes["gpu_vendor"]: str(
k8s_deco.attributes["gpu"]
)
}
)
)

annotations = {
"metaflow/production_token": self.production_token,
Expand Down
29 changes: 19 additions & 10 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from metaflow.metaflow_config_funcs import config_values
from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
from metaflow.parameters import deploy_time_eval
from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits
from metaflow.plugins.kubernetes.kubernetes import (
parse_kube_keyvalue_list,
validate_kube_labels,
Expand Down Expand Up @@ -1842,6 +1843,13 @@ def _container_templates(self):
if tmpfs_enabled and tmpfs_tempdir:
env["METAFLOW_TEMPDIR"] = tmpfs_path

qos_requests, qos_limits = qos_requests_and_limits(
resources["qos"],
resources["cpu"],
resources["memory"],
resources["disk"],
)

# Create a ContainerTemplate for this node. Ideally, we would have
# liked to inline this ContainerTemplate and avoid scanning the workflow
# twice, but due to issues with variable substitution, we will have to
Expand Down Expand Up @@ -1905,6 +1913,7 @@ def _container_templates(self):
persistent_volume_claims=resources["persistent_volume_claims"],
shared_memory=shared_memory,
port=port,
qos=resources["qos"],
)

for k, v in env.items():
Expand Down Expand Up @@ -2090,17 +2099,17 @@ def _container_templates(self):
image=resources["image"],
image_pull_policy=resources["image_pull_policy"],
resources=kubernetes_sdk.V1ResourceRequirements(
requests={
"cpu": str(resources["cpu"]),
"memory": "%sM" % str(resources["memory"]),
"ephemeral-storage": "%sM"
% str(resources["disk"]),
},
requests=qos_requests,
limits={
"%s.com/gpu".lower()
% resources["gpu_vendor"]: str(resources["gpu"])
for k in [0]
if resources["gpu"] is not None
**qos_limits,
**{
"%s.com/gpu".lower()
% resources["gpu_vendor"]: str(
resources["gpu"]
)
for k in [0]
if resources["gpu"] is not None
},
},
),
# Configure secrets
Expand Down
29 changes: 29 additions & 0 deletions metaflow/plugins/kubernetes/kube_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,32 @@ def parse_cli_options(flow_name, run_id, user, my_runs, echo):
raise CommandException("A previous run id was not found. Specify --run-id.")

return flow_name, run_id, user


def qos_requests_and_limits(qos: str, cpu: int, memory: int, storage: int):
"return resource requests and limits for the kubernetes pod based on the given QoS Class"
# case insensitive matching for QoS class
qos = qos.lower()
# Determine the requests and limits to define chosen QoS class
qos_limits = {}
qos_requests = {}
if qos == "guaranteed":
# Guaranteed - has both cpu/memory limits. requests not required, as these will be inferred.
qos_limits = {
"cpu": str(cpu),
"memory": "%sM" % str(memory),
"ephemeral-storage": "%sM" % str(storage),
}
# NOTE: Even though Kubernetes will produce matching requests for the specified limits, this happens late in the lifecycle.
# We specify them explicitly here to make some K8S tooling happy, in case they rely on .resources.requests being present at time of submitting the job.
qos_requests = qos_limits
else:
# Burstable - not Guaranteed, and has a memory/cpu limit or request
qos_requests = {
"cpu": str(cpu),
"memory": "%sM" % str(memory),
"ephemeral-storage": "%sM" % str(storage),
}
# TODO: Add support for BestEffort once there is a use case for it.
# BestEffort - no limit or requests for cpu/memory
return qos_requests, qos_limits
4 changes: 4 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def create_jobset(
shared_memory=None,
port=None,
num_parallel=None,
qos=None,
):
name = "js-%s" % str(uuid4())[:6]
jobset = (
Expand Down Expand Up @@ -228,6 +229,7 @@ def create_jobset(
shared_memory=shared_memory,
port=port,
num_parallel=num_parallel,
qos=qos,
)
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
.environment_variable("METAFLOW_CODE_URL", code_package_url)
Expand Down Expand Up @@ -488,6 +490,7 @@ def create_job_object(
shared_memory=None,
port=None,
name_pattern=None,
qos=None,
):
if env is None:
env = {}
Expand Down Expand Up @@ -528,6 +531,7 @@ def create_job_object(
persistent_volume_claims=persistent_volume_claims,
shared_memory=shared_memory,
port=port,
qos=qos,
)
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
.environment_variable("METAFLOW_CODE_URL", code_package_url)
Expand Down
8 changes: 8 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ def kubernetes():
type=int,
help="Number of parallel nodes to run as a multi-node job.",
)
@click.option(
"--qos",
default=None,
type=str,
help="Quality of Service class for the Kubernetes pod",
)
@click.pass_context
def step(
ctx,
Expand Down Expand Up @@ -154,6 +160,7 @@ def step(
shared_memory=None,
port=None,
num_parallel=None,
qos=None,
**kwargs
):
def echo(msg, stream="stderr", job_id=None, **kwargs):
Expand Down Expand Up @@ -294,6 +301,7 @@ def _sync_metadata():
shared_memory=shared_memory,
port=port,
num_parallel=num_parallel,
qos=qos,
)
except Exception as e:
traceback.print_exc(chain=False)
Expand Down
17 changes: 17 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
KUBERNETES_SERVICE_ACCOUNT,
KUBERNETES_SHARED_MEMORY,
KUBERNETES_TOLERATIONS,
KUBERNETES_QOS,
)
from metaflow.plugins.resources_decorator import ResourcesDecorator
from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
Expand All @@ -41,6 +42,8 @@
unicode = str
basestring = str

SUPPORTED_KUBERNETES_QOS_CLASSES = ["Guaranteed", "Burstable"]


class KubernetesDecorator(StepDecorator):
"""
Expand Down Expand Up @@ -109,6 +112,8 @@ class KubernetesDecorator(StepDecorator):
hostname_resolution_timeout: int, default 10 * 60
Timeout in seconds for the workers tasks in the gang scheduled cluster to resolve the hostname of control task.
Only applicable when @parallel is used.
qos: str, default: Burstable
Quality of Service class to assign to the pod. Supported values are: Guaranteed, Burstable, BestEffort
"""

name = "kubernetes"
Expand Down Expand Up @@ -136,6 +141,7 @@ class KubernetesDecorator(StepDecorator):
"compute_pool": None,
"executable": None,
"hostname_resolution_timeout": 10 * 60,
"qos": KUBERNETES_QOS,
}
package_url = None
package_sha = None
Expand Down Expand Up @@ -259,6 +265,17 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
self.step = step
self.flow_datastore = flow_datastore

if (
self.attributes["qos"] is not None
# case insensitive matching.
and self.attributes["qos"].lower()
not in [c.lower() for c in SUPPORTED_KUBERNETES_QOS_CLASSES]
):
raise MetaflowException(
"*%s* is not a valid Kubernetes QoS class. Choose one of the following: %s"
% (self.attributes["qos"], ", ".join(SUPPORTED_KUBERNETES_QOS_CLASSES))
)

if any([deco.name == "batch" for deco in decos]):
raise MetaflowException(
"Step *{step}* is marked for execution both on AWS Batch and "
Expand Down
33 changes: 20 additions & 13 deletions metaflow/plugins/kubernetes/kubernetes_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
KubernetesJobSet,
) # We need this import for Kubernetes Client.

from .kube_utils import qos_requests_and_limits


class KubernetesJobException(MetaflowException):
headline = "Kubernetes job error"
Expand Down Expand Up @@ -74,6 +76,13 @@ def create_job_spec(self):
if self._kwargs["shared_memory"]
else None
)
qos_requests, qos_limits = qos_requests_and_limits(
self._kwargs["qos"],
self._kwargs["cpu"],
self._kwargs["memory"],
self._kwargs["disk"],
)

return client.V1JobSpec(
# Retries are handled by Metaflow when it is responsible for
# executing the flow. The responsibility is moved to Kubernetes
Expand Down Expand Up @@ -154,20 +163,18 @@ def create_job_spec(self):
image_pull_policy=self._kwargs["image_pull_policy"],
name=self._kwargs["step_name"].replace("_", "-"),
resources=client.V1ResourceRequirements(
requests={
"cpu": str(self._kwargs["cpu"]),
"memory": "%sM" % str(self._kwargs["memory"]),
"ephemeral-storage": "%sM"
% str(self._kwargs["disk"]),
},
requests=qos_requests,
limits={
"%s.com/gpu".lower()
% self._kwargs["gpu_vendor"]: str(
self._kwargs["gpu"]
)
for k in [0]
# Don't set GPU limits if gpu isn't specified.
if self._kwargs["gpu"] is not None
**qos_limits,
**{
"%s.com/gpu".lower()
% self._kwargs["gpu_vendor"]: str(
self._kwargs["gpu"]
)
for k in [0]
# Don't set GPU limits if gpu isn't specified.
if self._kwargs["gpu"] is not None
},
},
),
volume_mounts=(
Expand Down
34 changes: 19 additions & 15 deletions metaflow/plugins/kubernetes/kubernetes_jobsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from metaflow.tracing import inject_tracing_vars
from metaflow.metaflow_config import KUBERNETES_SECRETS

from .kube_utils import qos_requests_and_limits


class KubernetesJobsetException(MetaflowException):
headline = "Kubernetes jobset error"
Expand Down Expand Up @@ -554,7 +556,12 @@ def dump(self):
if self._kwargs["shared_memory"]
else None
)

qos_requests, qos_limits = qos_requests_and_limits(
self._kwargs["qos"],
self._kwargs["cpu"],
self._kwargs["memory"],
self._kwargs["disk"],
)
return dict(
name=self.name,
template=client.api_client.ApiClient().sanitize_for_serialization(
Expand Down Expand Up @@ -653,21 +660,18 @@ def dump(self):
"_", "-"
),
resources=client.V1ResourceRequirements(
requests={
"cpu": str(self._kwargs["cpu"]),
"memory": "%sM"
% str(self._kwargs["memory"]),
"ephemeral-storage": "%sM"
% str(self._kwargs["disk"]),
},
requests=qos_requests,
limits={
"%s.com/gpu".lower()
% self._kwargs["gpu_vendor"]: str(
self._kwargs["gpu"]
)
for k in [0]
# Don't set GPU limits if gpu isn't specified.
if self._kwargs["gpu"] is not None
**qos_limits,
**{
"%s.com/gpu".lower()
% self._kwargs["gpu_vendor"]: str(
self._kwargs["gpu"]
)
for k in [0]
# Don't set GPU limits if gpu isn't specified.
if self._kwargs["gpu"] is not None
},
},
),
volume_mounts=(
Expand Down

0 comments on commit 0469eeb

Please sign in to comment.