Skip to content

Commit

Permalink
Add sky show-gpus support for Kubernetes (#2638)
Browse files Browse the repository at this point in the history
* Add sky show-gpus support for Kubernetes

* Update sky/clouds/service_catalog/kubernetes_catalog.py

Co-authored-by: Romil Bhardwaj <[email protected]>

* PR feedback

* PR feedback part 2

* Format fix

* PR feedback part 3

* Fix bug with checking enabled clouds in k8s list_accelerators

* Pylint fixes

* Pylint fixes part 2

* Pylint fixes part 3

* Pylint fixes part 4

---------

Co-authored-by: Romil Bhardwaj <[email protected]>
  • Loading branch information
hemildesai and romilbhardwaj authored Oct 7, 2023
1 parent 69b4752 commit 9ff1927
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 4 deletions.
12 changes: 9 additions & 3 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3238,9 +3238,6 @@ def show_gpus(
type is the lowest across all regions for both on-demand and spot
instances. There may be multiple regions with the same lowest price.
"""
# validation for the --cloud kubernetes
if cloud == 'kubernetes':
raise click.UsageError('Kubernetes does not have a service catalog.')
# validation for the --region flag
if region is not None and cloud is None:
raise click.UsageError(
Expand Down Expand Up @@ -3271,6 +3268,11 @@ def _output():
clouds=cloud,
region_filter=region,
)

if len(result) == 0 and cloud == 'kubernetes':
yield kubernetes_utils.NO_GPU_ERROR_MESSAGE
return

# "Common" GPUs
for gpu in service_catalog.get_common_gpus():
if gpu in result:
Expand Down Expand Up @@ -3327,6 +3329,10 @@ def _output():
case_sensitive=False)

if len(result) == 0:
if cloud == 'kubernetes':
yield kubernetes_utils.NO_GPU_ERROR_MESSAGE
return

quantity_str = (f' with requested quantity {quantity}'
if quantity else '')
yield f'Resources \'{name}\'{quantity_str} not found. '
Expand Down
7 changes: 7 additions & 0 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
if clouds is None:
clouds = list(_ALL_CLOUDS)

# TODO(hemil): Remove this once the common service catalog
# functions are refactored from clouds/kubernetes.py to
# kubernetes_catalog.py and add kubernetes to _ALL_CLOUDS
if method_name == 'list_accelerators':
clouds.append('kubernetes')

single = isinstance(clouds, str)
if single:
clouds = [clouds] # type: ignore
Expand Down
77 changes: 76 additions & 1 deletion sky/clouds/service_catalog/kubernetes_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
Kubernetes does not require a catalog of instances, but we need an image catalog
mapping SkyPilot image tags to corresponding container image tags.
"""
from typing import Dict, List, Optional, Set, Tuple

from typing import Optional
import pandas as pd

from sky import global_user_state
from sky.clouds import Kubernetes
from sky.clouds.service_catalog import CloudFilter
from sky.clouds.service_catalog import common
from sky.utils import kubernetes_utils

_PULL_FREQUENCY_HOURS = 7

Expand All @@ -26,3 +31,73 @@ def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]:
def is_image_tag_valid(tag: str, region: Optional[str]) -> bool:
"""Returns whether the image tag is valid."""
return common.is_image_tag_valid_impl(_image_df, tag, region)


def list_accelerators(
gpus_only: bool,
name_filter: Optional[str],
region_filter: Optional[str],
quantity_filter: Optional[int],
case_sensitive: bool = True
) -> Dict[str, List[common.InstanceTypeInfo]]:
k8s_cloud = Kubernetes()
if not any(
map(k8s_cloud.is_same_cloud, global_user_state.get_enabled_clouds())
) or not kubernetes_utils.check_credentials()[0]:
return {}

has_gpu = kubernetes_utils.detect_gpu_resource()
if not has_gpu:
return {}

label_formatter, _ = kubernetes_utils.detect_gpu_label_formatter()
if not label_formatter:
return {}

accelerators: Set[Tuple[str, int]] = set()
key = label_formatter.get_label_key()
nodes = kubernetes_utils.get_kubernetes_nodes()
for node in nodes:
if key in node.metadata.labels:
accelerator_name = label_formatter.get_accelerator_from_label_value(
node.metadata.labels.get(key))
accelerator_count = int(
node.status.allocatable.get('nvidia.com/gpu', 0))

if accelerator_name and accelerator_count > 0:
for count in range(1, accelerator_count + 1):
accelerators.add((accelerator_name, count))

result = []
for accelerator_name, accelerator_count in accelerators:
result.append(
common.InstanceTypeInfo(cloud='Kubernetes',
instance_type=None,
accelerator_name=accelerator_name,
accelerator_count=accelerator_count,
cpu_count=None,
device_memory=None,
memory=None,
price=0.0,
spot_price=0.0,
region='kubernetes'))

df = pd.DataFrame(result,
columns=[
'Cloud', 'InstanceType', 'AcceleratorName',
'AcceleratorCount', 'vCPUs', 'DeviceMemoryGiB',
'MemoryGiB', 'Price', 'SpotPrice', 'Region'
])
df['GpuInfo'] = True

return common.list_accelerators_impl('Kubernetes', df, gpus_only,
name_filter, region_filter,
quantity_filter, case_sensitive)


def validate_region_zone(
region_name: Optional[str],
zone_name: Optional[str],
clouds: CloudFilter = None # pylint: disable=unused-argument
) -> Tuple[Optional[str], Optional[str]]:
return (region_name, zone_name)
27 changes: 27 additions & 0 deletions sky/utils/kubernetes_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
'T': 2**40,
'P': 2**50,
}
NO_GPU_ERROR_MESSAGE = 'No GPUs found in Kubernetes cluster. \
If your cluster contains GPUs, make sure nvidia.com/gpu resource is available on the nodes and the node labels for identifying GPUs \
(e.g., skypilot.co/accelerators) are setup correctly. \
To further debug, run: sky check.'

logger = sky_logging.init_logger(__name__)

Expand Down Expand Up @@ -78,6 +82,11 @@ def get_label_value(cls, accelerator: str) -> str:
"""Given a GPU type, returns the label value to be used"""
raise NotImplementedError

@classmethod
def get_accelerator_from_label_value(cls, value: str) -> str:
"""Given a label value, returns the GPU type"""
raise NotImplementedError


def get_gke_accelerator_name(accelerator: str) -> str:
"""Returns the accelerator name for GKE clusters
Expand Down Expand Up @@ -111,6 +120,10 @@ def get_label_value(cls, accelerator: str) -> str:
# See sky.utils.kubernetes.gpu_labeler.
return accelerator.lower()

@classmethod
def get_accelerator_from_label_value(cls, value: str) -> str:
return value.upper()


class CoreWeaveLabelFormatter(GPULabelFormatter):
"""CoreWeave label formatter
Expand All @@ -129,6 +142,10 @@ def get_label_key(cls) -> str:
def get_label_value(cls, accelerator: str) -> str:
return accelerator.upper()

@classmethod
def get_accelerator_from_label_value(cls, value: str) -> str:
return value


class GKELabelFormatter(GPULabelFormatter):
"""GKE label formatter
Expand All @@ -147,6 +164,16 @@ def get_label_key(cls) -> str:
def get_label_value(cls, accelerator: str) -> str:
return get_gke_accelerator_name(accelerator)

@classmethod
def get_accelerator_from_label_value(cls, value: str) -> str:
if value.startswith('nvidia-tesla-'):
return value.replace('nvidia-tesla-', '').upper()
elif value.startswith('nvidia-'):
return value.replace('nvidia-', '').upper()
else:
raise ValueError(
f'Invalid accelerator name in GKE cluster: {value}')


# LABEL_FORMATTER_REGISTRY stores the label formats SkyPilot will try to
# discover the accelerator type from. The order of the list is important, as
Expand Down

0 comments on commit 9ff1927

Please sign in to comment.