[k8s] support to use custom gpu resource name if it's not nvidia.com/gpu #4337

Merged
merged 3 commits into from
Dec 19, 2024
2 changes: 1 addition & 1 deletion sky/clouds/kubernetes.py
@@ -395,7 +395,7 @@ def make_deploy_resources_variables(
tpu_requested = True
k8s_resource_key = kubernetes_utils.TPU_RESOURCE_KEY
else:
k8s_resource_key = kubernetes_utils.GPU_RESOURCE_KEY
k8s_resource_key = kubernetes_utils.get_gpu_resource_key()

port_mode = network_utils.get_port_mode(None)

12 changes: 7 additions & 5 deletions sky/provision/kubernetes/instance.py
@@ -180,6 +180,7 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes):
# case we will need to update this logic.
# TODO(Doyoung): Update the error message raised
# with the multi-host TPU support.
gpu_resource_key = kubernetes_utils.get_gpu_resource_key() # pylint: disable=line-too-long
if 'Insufficient google.com/tpu' in event_message:
extra_msg = (
f'Verify if '
@@ -192,14 +193,15 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes):
pod,
extra_msg,
details=event_message))
elif (('Insufficient nvidia.com/gpu'
elif ((f'Insufficient {gpu_resource_key}'
in event_message) or
('didn\'t match Pod\'s node affinity/selector'
in event_message)):
extra_msg = (
f'Verify if '
f'{pod.spec.node_selector[label_key]}'
' is available in the cluster.')
f'Verify if any node matching label '
f'{pod.spec.node_selector[label_key]} and '
f'sufficient resource {gpu_resource_key} '
f'is available in the cluster.')
raise config_lib.KubernetesError(
_lack_resource_msg('GPU',
pod,
@@ -728,7 +730,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
limits = pod_spec['spec']['containers'][0].get('resources',
{}).get('limits')
if limits is not None:
needs_gpus = limits.get(kubernetes_utils.GPU_RESOURCE_KEY, 0) > 0
needs_gpus = limits.get(kubernetes_utils.get_gpu_resource_key(), 0) > 0

# TPU pods provisioned on GKE use the default containerd runtime.
# Reference: https://cloud.google.com/kubernetes-engine/docs/how-to/migrate-containerd#overview # pylint: disable=line-too-long
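As a quick illustration of the limits lookup above, here is a self-contained sketch; the pod_spec fragment and the amd.com/gpu key are invented examples, not values from this PR:

```python
# Hypothetical pod spec fragment: only the fields the lookup touches are shown.
pod_spec = {
    'spec': {
        'containers': [{
            'resources': {'limits': {'amd.com/gpu': 2}},
        }],
    },
}

# What get_gpu_resource_key() would return with the override set (illustrative).
gpu_resource_key = 'amd.com/gpu'
limits = pod_spec['spec']['containers'][0].get('resources', {}).get('limits')
# needs_gpus is True only when the configured resource key appears with a positive limit.
needs_gpus = limits is not None and limits.get(gpu_resource_key, 0) > 0
print(needs_gpus)  # True
```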
24 changes: 20 additions & 4 deletions sky/provision/kubernetes/utils.py
@@ -438,7 +438,7 @@ def detect_accelerator_resource(
nodes = get_kubernetes_nodes(context)
for node in nodes:
cluster_resources.update(node.status.allocatable.keys())
has_accelerator = (GPU_RESOURCE_KEY in cluster_resources or
has_accelerator = (get_gpu_resource_key() in cluster_resources or
TPU_RESOURCE_KEY in cluster_resources)

return has_accelerator, cluster_resources
@@ -2253,10 +2253,11 @@ def get_node_accelerator_count(attribute_dict: dict) -> int:
Number of accelerators allocated or available from the node. If no
resource is found, it returns 0.
"""
assert not (GPU_RESOURCE_KEY in attribute_dict and
gpu_resource_name = get_gpu_resource_key()
Collaborator
nit:

Suggested change
gpu_resource_name = get_gpu_resource_key()
gpu_resource_key = get_gpu_resource_key()

assert not (gpu_resource_name in attribute_dict and
TPU_RESOURCE_KEY in attribute_dict)
if GPU_RESOURCE_KEY in attribute_dict:
return int(attribute_dict[GPU_RESOURCE_KEY])
if gpu_resource_name in attribute_dict:
return int(attribute_dict[gpu_resource_name])
elif TPU_RESOURCE_KEY in attribute_dict:
return int(attribute_dict[TPU_RESOURCE_KEY])
return 0
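To make the accelerator-count lookup above concrete, a small self-contained sketch follows. It is a simplified stand-in, not the PR's code: the real helper reads the key via get_gpu_resource_key() rather than taking it as a parameter, and the attribute dicts below are invented examples:

```python
TPU_RESOURCE_KEY = 'google.com/tpu'


def node_accelerator_count(attribute_dict: dict, gpu_resource_key: str) -> int:
    """Count accelerators from a node's allocatable/capacity dict: GPU key first, then TPU key."""
    # A node is not expected to advertise both the GPU and the TPU resource.
    assert not (gpu_resource_key in attribute_dict and
                TPU_RESOURCE_KEY in attribute_dict)
    if gpu_resource_key in attribute_dict:
        return int(attribute_dict[gpu_resource_key])
    if TPU_RESOURCE_KEY in attribute_dict:
        return int(attribute_dict[TPU_RESOURCE_KEY])
    return 0


print(node_accelerator_count({'amd.com/gpu': '4'}, 'amd.com/gpu'))     # 4
print(node_accelerator_count({'google.com/tpu': '8'}, 'amd.com/gpu'))  # 8
print(node_accelerator_count({'cpu': '16'}, 'amd.com/gpu'))            # 0
```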
@@ -2415,3 +2416,18 @@ def process_skypilot_pods(
num_pods = len(cluster.pods)
cluster.resources_str = f'{num_pods}x {cluster.resources}'
return list(clusters.values()), jobs_controllers, serve_controllers


def get_gpu_resource_key():
Collaborator
Can we also update L194 in sky/provision/kubernetes/instance.py to use get_gpu_resource_key() to make sure errors are handled correctly?

elif (('Insufficient nvidia.com/gpu'

Contributor Author
fixed

"""Get the GPU resource name to use in kubernetes.
The function first checks for an environment variable.
If defined, it uses its value; otherwise, it returns the default value.
Args:
name (str): Default GPU resource name, default is "nvidia.com/gpu".
Returns:
str: The selected GPU resource name.
"""
# Retrieve GPU resource name from environment variable, if set.
# Else use default.
# E.g., can be nvidia.com/gpu-h100, amd.com/gpu etc.
return os.getenv('CUSTOM_GPU_RESOURCE_KEY', default=GPU_RESOURCE_KEY)
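For reference, a minimal standalone sketch of the expected override behavior, assuming the environment variable is read at call time as in the diff above; the amd.com/gpu value is only an illustrative example of a non-NVIDIA device-plugin resource:

```python
import os

GPU_RESOURCE_KEY = 'nvidia.com/gpu'  # default, mirrors the constant used in the diff


def get_gpu_resource_key():
    """Return the GPU resource name, honoring the CUSTOM_GPU_RESOURCE_KEY override."""
    return os.getenv('CUSTOM_GPU_RESOURCE_KEY', default=GPU_RESOURCE_KEY)


print(get_gpu_resource_key())  # nvidia.com/gpu

os.environ['CUSTOM_GPU_RESOURCE_KEY'] = 'amd.com/gpu'  # e.g. when using the AMD device plugin
print(get_gpu_resource_key())  # amd.com/gpu
```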
4 changes: 2 additions & 2 deletions sky/utils/kubernetes/gpu_labeler.py
@@ -101,7 +101,7 @@ def label():
# Get the list of nodes with GPUs
gpu_nodes = []
for node in nodes:
if kubernetes_utils.GPU_RESOURCE_KEY in node.status.capacity:
if kubernetes_utils.get_gpu_resource_key() in node.status.capacity:
gpu_nodes.append(node)

print(f'Found {len(gpu_nodes)} GPU nodes in the cluster')
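For readers adapting this capacity check to a non-NVIDIA device plugin, here is a hedged standalone sketch of the same filter using the official kubernetes Python client; it assumes a working local kubeconfig, and CUSTOM_GPU_RESOURCE_KEY is whatever custom resource your cluster actually advertises:

```python
import os

from kubernetes import client, config

config.load_kube_config()  # assumes cluster access via the local kubeconfig
gpu_key = os.getenv('CUSTOM_GPU_RESOURCE_KEY', 'nvidia.com/gpu')

# Keep only the nodes whose advertised capacity includes the configured GPU resource.
gpu_nodes = [
    node for node in client.CoreV1Api().list_node().items
    if gpu_key in (node.status.capacity or {})
]
print(f'Found {len(gpu_nodes)} nodes exposing {gpu_key}')
```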
@@ -142,7 +142,7 @@ def label():
if len(gpu_nodes) == 0:
print('No GPU nodes found in the cluster. If you have GPU nodes, '
'please ensure that they have the label '
f'`{kubernetes_utils.GPU_RESOURCE_KEY}: <number of GPUs>`')
f'`{kubernetes_utils.get_gpu_resource_key()}: <number of GPUs>`')
else:
print('GPU labeling started - this may take 10 min or more to complete.'
'\nTo check the status of GPU labeling jobs, run '