-
Notifications
You must be signed in to change notification settings - Fork 531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[k8s] support to use custom gpu resource name if it's not nvidia.com/gpu #4337
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -438,7 +438,7 @@ def detect_accelerator_resource( | |||
nodes = get_kubernetes_nodes(context) | ||||
for node in nodes: | ||||
cluster_resources.update(node.status.allocatable.keys()) | ||||
has_accelerator = (GPU_RESOURCE_KEY in cluster_resources or | ||||
has_accelerator = (get_gpu_resource_key() in cluster_resources or | ||||
TPU_RESOURCE_KEY in cluster_resources) | ||||
|
||||
return has_accelerator, cluster_resources | ||||
|
@@ -2253,10 +2253,11 @@ def get_node_accelerator_count(attribute_dict: dict) -> int: | |||
Number of accelerators allocated or available from the node. If no | ||||
resource is found, it returns 0. | ||||
""" | ||||
assert not (GPU_RESOURCE_KEY in attribute_dict and | ||||
gpu_resource_name = get_gpu_resource_key() | ||||
assert not (gpu_resource_name in attribute_dict and | ||||
TPU_RESOURCE_KEY in attribute_dict) | ||||
if GPU_RESOURCE_KEY in attribute_dict: | ||||
return int(attribute_dict[GPU_RESOURCE_KEY]) | ||||
if gpu_resource_name in attribute_dict: | ||||
return int(attribute_dict[gpu_resource_name]) | ||||
elif TPU_RESOURCE_KEY in attribute_dict: | ||||
return int(attribute_dict[TPU_RESOURCE_KEY]) | ||||
return 0 | ||||
|
@@ -2415,3 +2416,18 @@ def process_skypilot_pods( | |||
num_pods = len(cluster.pods) | ||||
cluster.resources_str = f'{num_pods}x {cluster.resources}' | ||||
return list(clusters.values()), jobs_controllers, serve_controllers | ||||
|
||||
|
||||
def get_gpu_resource_key(): | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we also update L194 in sky/provision/kubernetes/instance.py to use get_gpu_resource_key() to make sure errors are handled correctly? skypilot/sky/provision/kubernetes/instance.py Line 194 in c1726ae
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||||
"""Get the GPU resource name to use in kubernetes. | ||||
The function first checks for an environment variable. | ||||
If defined, it uses its value; otherwise, it returns the default value. | ||||
Args: | ||||
name (str): Default GPU resource name, default is "nvidia.com/gpu". | ||||
Returns: | ||||
str: The selected GPU resource name. | ||||
""" | ||||
# Retrieve GPU resource name from environment variable, if set. | ||||
# Else use default. | ||||
# E.g., can be nvidia.com/gpu-h100, amd.com/gpu etc. | ||||
return os.getenv('CUSTOM_GPU_RESOURCE_KEY', default=GPU_RESOURCE_KEY) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: