From a03f546588b73786d823e6b8fb68e8a836ab4fad Mon Sep 17 00:00:00 2001
From: nkwangleiGIT <nkwanglei@126.com>
Date: Thu, 14 Nov 2024 16:33:36 +0800
Subject: [PATCH] [k8s] support to use custom gpu resource name if it's not
 nvidia.com/gpu

Signed-off-by: nkwangleiGIT <nkwanglei@126.com>
---
 sky/clouds/kubernetes.py             |  2 +-
 sky/provision/kubernetes/instance.py | 11 ++++++-----
 sky/provision/kubernetes/utils.py    | 23 +++++++++++++++++++----
 sky/utils/kubernetes/gpu_labeler.py  |  4 ++--
 4 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py
index 5e1b46d52eb..aec7bdfad63 100644
--- a/sky/clouds/kubernetes.py
+++ b/sky/clouds/kubernetes.py
@@ -378,7 +378,7 @@ def make_deploy_resources_variables(
                 tpu_requested = True
                 k8s_resource_key = kubernetes_utils.TPU_RESOURCE_KEY
             else:
-                k8s_resource_key = kubernetes_utils.GPU_RESOURCE_KEY
+                k8s_resource_key = kubernetes_utils.get_gpu_resource_key()
 
         port_mode = network_utils.get_port_mode(None)
 
diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py
index 2dcf38f2365..437a8528c49 100644
--- a/sky/provision/kubernetes/instance.py
+++ b/sky/provision/kubernetes/instance.py
@@ -179,6 +179,7 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes):
                             #  case we will need to update this logic.
                             # TODO(Doyoung): Update the error message raised
                             # with the multi-host TPU support.
+                            gpu_resource_key = kubernetes_utils.get_gpu_resource_key()
                             if 'Insufficient google.com/tpu' in event_message:
                                 extra_msg = (
                                     f'Verify if '
@@ -191,14 +192,14 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes):
                                                        pod,
                                                        extra_msg,
                                                        details=event_message))
-                            elif (('Insufficient nvidia.com/gpu'
+                            elif ((f'Insufficient {gpu_resource_key}'
                                    in event_message) or
                                   ('didn\'t match Pod\'s node affinity/selector'
                                    in event_message)):
                                 extra_msg = (
-                                    f'Verify if '
-                                    f'{pod.spec.node_selector[label_key]}'
-                                    ' is available in the cluster.')
+                                    f'Verify if any node matching label '
+                                    f'{pod.spec.node_selector[label_key]} and sufficient '
+                                    f'resource {gpu_resource_key} is available in the cluster.')
                                 raise config_lib.KubernetesError(
                                     _lack_resource_msg('GPU',
                                                        pod,
@@ -685,7 +686,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
     limits = pod_spec['spec']['containers'][0].get('resources',
                                                    {}).get('limits')
     if limits is not None:
-        needs_gpus = limits.get(kubernetes_utils.GPU_RESOURCE_KEY, 0) > 0
+        needs_gpus = limits.get(kubernetes_utils.get_gpu_resource_key(), 0) > 0
 
     # TPU pods provisioned on GKE use the default containerd runtime.
     # Reference: https://cloud.google.com/kubernetes-engine/docs/how-to/migrate-containerd#overview  # pylint: disable=line-too-long
diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py
index e5bc4228a8e..3e46fb8af4f 100644
--- a/sky/provision/kubernetes/utils.py
+++ b/sky/provision/kubernetes/utils.py
@@ -438,7 +438,7 @@ def detect_accelerator_resource(
     nodes = get_kubernetes_nodes(context)
     for node in nodes:
         cluster_resources.update(node.status.allocatable.keys())
-    has_accelerator = (GPU_RESOURCE_KEY in cluster_resources or
+    has_accelerator = (get_gpu_resource_key() in cluster_resources or
                        TPU_RESOURCE_KEY in cluster_resources)
 
     return has_accelerator, cluster_resources
@@ -2233,10 +2233,11 @@ def get_node_accelerator_count(attribute_dict: dict) -> int:
         Number of accelerators allocated or available from the node. If no
             resource is found, it returns 0.
     """
-    assert not (GPU_RESOURCE_KEY in attribute_dict and
+    gpu_resource_name = get_gpu_resource_key()
+    assert not (gpu_resource_name in attribute_dict and
                 TPU_RESOURCE_KEY in attribute_dict)
-    if GPU_RESOURCE_KEY in attribute_dict:
-        return int(attribute_dict[GPU_RESOURCE_KEY])
+    if gpu_resource_name in attribute_dict:
+        return int(attribute_dict[gpu_resource_name])
     elif TPU_RESOURCE_KEY in attribute_dict:
         return int(attribute_dict[TPU_RESOURCE_KEY])
     return 0
@@ -2395,3 +2396,17 @@ def process_skypilot_pods(
         num_pods = len(cluster.pods)
         cluster.resources_str = f'{num_pods}x {cluster.resources}'
     return list(clusters.values()), jobs_controllers, serve_controllers
+
+
+def get_gpu_resource_key() -> str:
+    """Get the GPU resource name to use in kubernetes.
+
+    Checks the CUSTOM_GPU_RESOURCE_KEY environment variable first; if it
+    is set, its value is used, otherwise GPU_RESOURCE_KEY is returned.
+
+    Returns:
+        str: The selected GPU resource name.
+    """
+    # Retrieve GPU resource name from environment variable, if set.
+    # E.g., can be nvidia.com/gpu-h100, amd.com/gpu etc.
+    return os.getenv('CUSTOM_GPU_RESOURCE_KEY', default=GPU_RESOURCE_KEY)
diff --git a/sky/utils/kubernetes/gpu_labeler.py b/sky/utils/kubernetes/gpu_labeler.py
index 14fbbdedca5..5618cac915c 100644
--- a/sky/utils/kubernetes/gpu_labeler.py
+++ b/sky/utils/kubernetes/gpu_labeler.py
@@ -101,7 +101,7 @@ def label():
         # Get the list of nodes with GPUs
         gpu_nodes = []
         for node in nodes:
-            if kubernetes_utils.GPU_RESOURCE_KEY in node.status.capacity:
+            if kubernetes_utils.get_gpu_resource_key() in node.status.capacity:
                 gpu_nodes.append(node)
 
         print(f'Found {len(gpu_nodes)} GPU nodes in the cluster')
@@ -142,7 +142,7 @@ def label():
     if len(gpu_nodes) == 0:
         print('No GPU nodes found in the cluster. If you have GPU nodes, '
               'please ensure that they have the label '
-              f'`{kubernetes_utils.GPU_RESOURCE_KEY}: <number of GPUs>`')
+              f'`{kubernetes_utils.get_gpu_resource_key()}: <number of GPUs>`')
     else:
         print('GPU labeling started - this may take 10 min or more to complete.'
               '\nTo check the status of GPU labeling jobs, run '