
Commit eea13cc

[k8s] On-demand single-host TPU support on GKE (#3947)
* initial version of TPU support on GKE
* revert unnecessary change
* revert
* use TPU_LABEL_KEY constant
* nit
* nit
* update detect_gpu_label_formatter() to use match_label_key()
* tidy get_gpu_label_key_value
* nit
* update method name
* update get_gke_accelerator_name to support TPU
* add support for get_label_keys method due to TPU label key
* syntax
* update get_tpu_topology_label_key_value
* nit
* refactor error surfacing methods to work with TPU support
* update toleration comment
* support listing available TPUs and show-gpus for TPUs
* nit
* update help message
* Update /tmp/tpu_logs dir's write permission
* nit
* nit
* comment update on error handling for lack of TPU resources
* Update to use global constants instead of hard-coded strings for nvidia.com/gpu and google.com/tpu
* add smoke test and make exec work on TPU pods
* update smoke test to check if TPU is reachable
* add comment
* nit
* Comment on number of requested TPU chips for multi- and single-host TPU slices
* update method to check GKE supported TPU name
* nit
* move is_tpu_pod_slice to kubernetes_utils
* update get_accelerator_from_label_value to use is_tpu_pod_slice method
* nit
* format
* nit
* check acc count support
* preemptive TPU check
* update check_tpu_fits
* error msg update
* merge get_tpu_topology_label_key_value into get_gpu_label_key_value
* Update sky/provision/kubernetes/utils.py (Co-authored-by: Tian Xia <[email protected]>)
* nit fixes
* format
* nit
* Implement method for reading acc counts from node/pod object
* assertion update for is_tpu_vm
* Exclude multi-host TPUs from being displayed in show-gpus
* Notify users via 'sky show-gpus' that multi-host TPUs are not supported
* format
* nit
* display warning message from show-gpus conditionally
* update sky show-gpus
* update get_accelerator_label_key_value
* format
* format
* format
* update comment
* resolve review comments
* update tpuvm_mnist.yaml
* resolve comments
* update display message for show-gpus
* format

---------

Co-authored-by: Tian Xia <[email protected]>
1 parent 2030398 commit eea13cc

File tree

12 files changed, +691 -267 lines


examples/tpu/tpuvm_mnist.yaml

+2 -2

@@ -5,7 +5,7 @@ resources:
 
 # The setup command. Will be run under the working directory.
 setup: |
-  git clone https://github.com/google/flax.git --branch v0.8.2
+  git clone https://github.com/google/flax.git --branch v0.10.1
 
   conda activate flax
   if [ $? -eq 0 ]; then
@@ -15,7 +15,7 @@ setup: |
   conda activate flax
   # Make sure to install TPU related packages in a conda env to avoid package conflicts.
   pip install \
-    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.25" \
+    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.35" \
     clu \
     tensorflow tensorflow-datasets
   pip install -e flax
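
The commit list above mentions a smoke test that checks whether the TPU is reachable from inside the provisioned pod. A minimal sketch of such a check, assuming the conda env with jax[tpu] from the setup above is active (this is illustrative, not the PR's actual smoke test):

    # Minimal TPU reachability check (illustrative; not the PR's smoke test).
    import jax

    devices = jax.devices()  # enumerates the accelerator devices JAX can see
    tpu_devices = [d for d in devices if d.platform == 'tpu']
    assert tpu_devices, 'No TPU devices detected'
    print(f'Found {len(tpu_devices)} TPU core(s):', tpu_devices)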

sky/cli.py

+18 -6

@@ -3143,7 +3143,8 @@ def _get_kubernetes_realtime_gpu_table(
                        'in Kubernetes cluster. ')
            debug_msg = ('To show available accelerators on kubernetes,'
                         ' run: sky show-gpus --cloud kubernetes ')
-            full_err_msg = (err_msg + kubernetes_utils.NO_GPU_HELP_MESSAGE +
+            full_err_msg = (err_msg +
+                            kubernetes_utils.NO_ACCELERATOR_HELP_MESSAGE +
                            debug_msg)
            raise ValueError(full_err_msg)
    for gpu, _ in sorted(counts.items()):
@@ -3161,11 +3162,12 @@ def _get_kubernetes_node_info_table(context: Optional[str]):
 
    node_info_dict = kubernetes_utils.get_kubernetes_node_info(context)
    for node_name, node_info in node_info_dict.items():
-        available = node_info.free['nvidia.com/gpu'] if node_info.free[
-            'nvidia.com/gpu'] != -1 else no_permissions_str
+        available = node_info.free[
+            'accelerators_available'] if node_info.free[
+                'accelerators_available'] != -1 else no_permissions_str
        node_table.add_row([
-            node_name, node_info.gpu_type,
-            node_info.total['nvidia.com/gpu'], available
+            node_name, node_info.accelerator_type,
+            node_info.total['accelerator_count'], available
        ])
    return node_table
 
@@ -3220,8 +3222,18 @@ def _output():
            yield from k8s_realtime_table.get_string()
            k8s_node_table = _get_kubernetes_node_info_table(context)
            yield '\n\n'
+            # TODO(Doyoung): Update the message with the multi-host TPU
+            # support.
+            k8s_per_node_acc_message = (
+                'Kubernetes per node accelerator availability ')
+            if kubernetes_utils.multi_host_tpu_exists_in_cluster(
+                    context):
+                k8s_per_node_acc_message += (
+                    '(Note: Multi-host TPUs are detected and excluded '
+                    'from the display as multi-host TPUs are not '
+                    'supported.)')
            yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
-                   f'Kubernetes per node GPU availability'
+                   f'{k8s_per_node_acc_message}'
                   f'{colorama.Style.RESET_ALL}\n')
            yield from k8s_node_table.get_string()
        if kubernetes_autoscaling:
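
For context on the -1 sentinel in the node table above: a free count of -1 means the per-pod usage could not be read (e.g. insufficient permissions), so the table prints a placeholder instead of a number. A standalone sketch of that pattern (names are illustrative, not SkyPilot's actual API):

    # Illustrative sketch of the sentinel handling in the node table above.
    NO_PERMISSIONS_STR = '-'

    def render_available(free_count: int) -> str:
        # -1 signals that accelerator usage could not be queried on the node.
        return NO_PERMISSIONS_STR if free_count == -1 else str(free_count)

    assert render_available(-1) == '-'
    assert render_available(4) == '4'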

sky/clouds/kubernetes.py

+19 -3

@@ -362,11 +362,23 @@ def make_deploy_resources_variables(
 
        k8s_acc_label_key = None
        k8s_acc_label_value = None
+        k8s_topology_label_key = None
+        k8s_topology_label_value = None
+        k8s_resource_key = None
+        tpu_requested = False
 
-        # If GPUs are requested, set node label to match the GPU type.
+        # If GPU/TPUs are requested, set node label to match the GPU/TPU type.
        if acc_count > 0 and acc_type is not None:
-            k8s_acc_label_key, k8s_acc_label_value = \
-                kubernetes_utils.get_gpu_label_key_value(context, acc_type)
+            (k8s_acc_label_key, k8s_acc_label_value, k8s_topology_label_key,
+             k8s_topology_label_value) = (
+                 kubernetes_utils.get_accelerator_label_key_value(
+                     context, acc_type, acc_count))
+            if (k8s_acc_label_key ==
+                    kubernetes_utils.GKELabelFormatter.TPU_LABEL_KEY):
+                tpu_requested = True
+                k8s_resource_key = kubernetes_utils.TPU_RESOURCE_KEY
+            else:
+                k8s_resource_key = kubernetes_utils.GPU_RESOURCE_KEY
 
        port_mode = network_utils.get_port_mode(None)
 
@@ -428,6 +440,10 @@ def make_deploy_resources_variables(
            'k8s_skypilot_system_namespace': _SKYPILOT_SYSTEM_NAMESPACE,
            'k8s_spot_label_key': spot_label_key,
            'k8s_spot_label_value': spot_label_value,
+            'tpu_requested': tpu_requested,
+            'k8s_topology_label_key': k8s_topology_label_key,
+            'k8s_topology_label_value': k8s_topology_label_value,
+            'k8s_resource_key': k8s_resource_key,
            'image_id': image_id,
        }
 
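
These new template variables are consumed by the provisioning pod template. As a rough sketch of where such values typically land in a rendered GKE pod spec (expressed here as a Python dict; the label keys, values, and chip count below are illustrative examples, not output of this PR):

    # Illustrative mapping of the variables above onto a GKE pod spec.
    rendered_pod_spec = {
        'nodeSelector': {
            # k8s_acc_label_key -> k8s_acc_label_value
            'cloud.google.com/gke-tpu-accelerator': 'tpu-v5-lite-podslice',
            # k8s_topology_label_key -> k8s_topology_label_value
            'cloud.google.com/gke-tpu-topology': '2x4',
        },
        'containers': [{
            'resources': {
                # k8s_resource_key: 'google.com/tpu' when tpu_requested is
                # True, 'nvidia.com/gpu' otherwise.
                'limits': {'google.com/tpu': 8},
                'requests': {'google.com/tpu': 8},
            },
        }],
    }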

sky/clouds/service_catalog/kubernetes_catalog.py

+83 -65

@@ -104,16 +104,16 @@ def list_accelerators_realtime(
    ) or not kubernetes_utils.check_credentials(context)[0]:
        return {}, {}, {}
 
-    has_gpu = kubernetes_utils.detect_gpu_resource(context)
+    has_gpu = kubernetes_utils.detect_accelerator_resource(context)
    if not has_gpu:
        return {}, {}, {}
 
-    label_formatter, _ = kubernetes_utils.detect_gpu_label_formatter(context)
-    if not label_formatter:
+    lf, _ = kubernetes_utils.detect_gpu_label_formatter(context)
+    if not lf:
        return {}, {}, {}
 
    accelerators_qtys: Set[Tuple[str, int]] = set()
-    key = label_formatter.get_label_key()
+    keys = lf.get_label_keys()
    nodes = kubernetes_utils.get_kubernetes_nodes(context)
    # Get the pods to get the real-time GPU usage
    try:
@@ -134,67 +134,85 @@
    min_quantity_filter = quantity_filter if quantity_filter else 1
 
    for node in nodes:
-        if key in node.metadata.labels:
-            allocated_qty = 0
-            accelerator_name = label_formatter.get_accelerator_from_label_value(
-                node.metadata.labels.get(key))
-
-            # Check if name_filter regex matches the accelerator_name
-            regex_flags = 0 if case_sensitive else re.IGNORECASE
-            if name_filter and not re.match(
-                    name_filter, accelerator_name, flags=regex_flags):
-                continue
-
-            accelerator_count = int(
-                node.status.allocatable.get('nvidia.com/gpu', 0))
-
-            # Generate the GPU quantities for the accelerators
-            if accelerator_name and accelerator_count > 0:
-                count = 1
-                while count <= accelerator_count:
-                    accelerators_qtys.add((accelerator_name, count))
-                    count *= 2
-                # Add the accelerator count if it's not already in the set
-                # (e.g., if there's 12 GPUs, we should have qtys 1, 2, 4, 8, 12)
-                if accelerator_count not in accelerators_qtys:
-                    accelerators_qtys.add((accelerator_name, accelerator_count))
-
-            if accelerator_count >= min_quantity_filter:
-                quantized_count = (min_quantity_filter *
-                                   (accelerator_count // min_quantity_filter))
-                if accelerator_name not in total_accelerators_capacity:
-                    total_accelerators_capacity[
-                        accelerator_name] = quantized_count
-                else:
-                    total_accelerators_capacity[
-                        accelerator_name] += quantized_count
-
-            if pods is None:
-                # If we can't get the pods, we can't get the GPU usage
-                total_accelerators_available[accelerator_name] = -1
-                continue
-
-            for pod in pods:
-                # Get all the pods running on the node
-                if (pod.spec.node_name == node.metadata.name and
-                        pod.status.phase in ['Running', 'Pending']):
-                    # Iterate over all the containers in the pod and sum the
-                    # GPU requests
-                    for container in pod.spec.containers:
-                        if container.resources.requests:
-                            allocated_qty += int(
-                                container.resources.requests.get(
-                                    'nvidia.com/gpu', 0))
-
-            accelerators_available = accelerator_count - allocated_qty
-
-            if accelerator_name not in total_accelerators_available:
-                total_accelerators_available[accelerator_name] = 0
-            if accelerators_available >= min_quantity_filter:
-                quantized_availability = min_quantity_filter * (
-                    accelerators_available // min_quantity_filter)
-                total_accelerators_available[
-                    accelerator_name] += quantized_availability
+        for key in keys:
+            if key in node.metadata.labels:
+                allocated_qty = 0
+                accelerator_name = lf.get_accelerator_from_label_value(
+                    node.metadata.labels.get(key))
+
+                # Exclude multi-host TPUs from being processed.
+                # TODO(Doyoung): Remove the logic when adding support for
+                # multi-host TPUs.
+                if kubernetes_utils.is_multi_host_tpu(node.metadata.labels):
+                    continue
+
+                # Check if name_filter regex matches the accelerator_name
+                regex_flags = 0 if case_sensitive else re.IGNORECASE
+                if name_filter and not re.match(
+                        name_filter, accelerator_name, flags=regex_flags):
+                    continue
+
+                # Generate the accelerator quantities
+                accelerator_count = (
+                    kubernetes_utils.get_node_accelerator_count(
+                        node.status.allocatable))
+
+                if accelerator_name and accelerator_count > 0:
+                    # TPUs are counted in a different way compared to GPUs.
+                    # Multi-node GPUs can be split into smaller units and be
+                    # provisioned, but TPUs are considered as an atomic unit.
+                    if kubernetes_utils.is_tpu_on_gke(accelerator_name):
+                        accelerators_qtys.add(
+                            (accelerator_name, accelerator_count))
+                    else:
+                        count = 1
+                        while count <= accelerator_count:
+                            accelerators_qtys.add((accelerator_name, count))
+                            count *= 2
+                        # Add the accelerator count if it's not already in the
+                        # set (e.g., if there's 12 GPUs, we should have qtys 1,
+                        # 2, 4, 8, 12)
+                        if accelerator_count not in accelerators_qtys:
+                            accelerators_qtys.add(
+                                (accelerator_name, accelerator_count))
+
+                if accelerator_count >= min_quantity_filter:
+                    quantized_count = (
+                        min_quantity_filter *
+                        (accelerator_count // min_quantity_filter))
+                    if accelerator_name not in total_accelerators_capacity:
+                        total_accelerators_capacity[
+                            accelerator_name] = quantized_count
+                    else:
+                        total_accelerators_capacity[
+                            accelerator_name] += quantized_count
+
+                if pods is None:
+                    # If we can't get the pods, we can't get the GPU usage
+                    total_accelerators_available[accelerator_name] = -1
+                    continue
+
+                for pod in pods:
+                    # Get all the pods running on the node
+                    if (pod.spec.node_name == node.metadata.name and
+                            pod.status.phase in ['Running', 'Pending']):
+                        # Iterate over all the containers in the pod and sum
+                        # the GPU requests
+                        for container in pod.spec.containers:
+                            if container.resources.requests:
+                                allocated_qty += (
+                                    kubernetes_utils.get_node_accelerator_count(
+                                        container.resources.requests))
+
+                accelerators_available = accelerator_count - allocated_qty
+
+                if accelerator_name not in total_accelerators_available:
+                    total_accelerators_available[accelerator_name] = 0
+                if accelerators_available >= min_quantity_filter:
+                    quantized_availability = min_quantity_filter * (
+                        accelerators_available // min_quantity_filter)
+                    total_accelerators_available[
+                        accelerator_name] += quantized_availability
 
    result = []
 
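
The quantity generation above treats GKE TPUs and GPUs differently: GPU nodes advertise power-of-two counts up to the node total (plus the total itself), while a TPU slice is only offered as a whole. A simplified, standalone sketch of that branch (the real function also applies name/quantity filters and aggregates per-cluster totals):

    from typing import Set, Tuple

    def candidate_quantities(acc_name: str, acc_count: int,
                             is_tpu: bool) -> Set[Tuple[str, int]]:
        """Simplified sketch of the qty generation in list_accelerators_realtime."""
        qtys: Set[Tuple[str, int]] = set()
        if acc_count <= 0:
            return qtys
        if is_tpu:
            # TPU slices are atomic: only the full chip count is offered.
            qtys.add((acc_name, acc_count))
        else:
            count = 1
            while count <= acc_count:
                qtys.add((acc_name, count))
                count *= 2
            qtys.add((acc_name, acc_count))  # e.g. 12 GPUs -> 1, 2, 4, 8, 12
        return qtys

    assert candidate_quantities('L4', 12, is_tpu=False) == {
        ('L4', 1), ('L4', 2), ('L4', 4), ('L4', 8), ('L4', 12)}
    assert candidate_quantities('tpu-v5-lite-podslice', 4, is_tpu=True) == {
        ('tpu-v5-lite-podslice', 4)}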

sky/clouds/utils/gcp_utils.py

+5 -1

@@ -17,6 +17,7 @@
 from sky import sky_logging
 from sky import skypilot_config
 from sky.provision.gcp import constants
+from sky.provision.kubernetes import utils as kubernetes_utils
 from sky.utils import subprocess_utils
 
 if typing.TYPE_CHECKING:
@@ -35,7 +36,10 @@ def is_tpu(resources: Optional['resources_lib.Resources']) -> bool:
 def is_tpu_vm(resources: Optional['resources_lib.Resources']) -> bool:
    if not is_tpu(resources):
        return False
-    assert resources is not None
+    assert (resources is not None and len(resources.accelerators) == 1)
+    acc, _ = list(resources.accelerators.items())[0]
+    if kubernetes_utils.is_tpu_on_gke(acc):
+        return False
    if resources.accelerator_args is None:
        return True
    return resources.accelerator_args.get('tpu_vm', True)
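
The new early return makes is_tpu_vm() treat GKE TPU accelerators as non-TPU-VM resources, so only Cloud TPU VM requests keep the TPU-VM code path. A minimal sketch of the naming distinction it relies on (the helper below is illustrative, not the actual kubernetes_utils.is_tpu_on_gke; 'tpu-v5-lite-podslice' is a GKE-style accelerator name, 'tpu-v4-8' a Cloud TPU VM name):

    def looks_like_gke_tpu(acc_name: str) -> bool:
        # Illustrative heuristic only: GKE node pools expose TPU slices with
        # '...-podslice' / '...-slice' style names, while Cloud TPU VMs use
        # names like 'tpu-v4-8'.
        return acc_name.startswith('tpu-') and acc_name.endswith('slice')

    assert looks_like_gke_tpu('tpu-v5-lite-podslice') is True
    assert looks_like_gke_tpu('tpu-v4-8') is False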
