[k8s] On-demand single-host TPU support on GKE (#3947)
* initial version of TPU support on GKE

* revert unnecessary change

* revert

* use TPU_LABEL_KEY constant

* nit

* nit

* update detect_gpu_label_formatter() to use match_label_key()

* tidy get_gpu_label_key_value

* nit

* update method name

* update get_gke_accelerator_name to support TPU

* add support for get_label_keys method due to TPU label key

* syntax

* update get_tpu_topology_label_key_value

* nit

* refactor error surfacing methods to have them work with TPU support

* update toleration comment

* support listing available TPUs and show-gpus for TPUs

* nit

* update help message

* Update /tmp/tpu_logs dir's write permission

* nit

* nit

* comment update on error handling for insufficient TPU resources

* Update to use global constants instead of hard-coded strings for nvidia.com/gpu and google.com/tpu

* add smoke test and make exec work on TPU pods

* update smoke test to check if TPU is reachable.

* add comment

* nit

* Comment on the number of requested TPU chips for multi- and single-host TPU slices.

* update method to check GKE supported TPU name

* nit

* move is_tpu_pod_slice to kubernetes_utils

* update get_accelerator_from_label_value to use is_tpu_pod_slice method

* nit

* format

* nit

* check acc count support

* preemptive TPU check

* update check_tpu_fits

* error msg update

* merge get_tpu_topology_label_key_value into get_gpu_label_key_value

* Update sky/provision/kubernetes/utils.py

Co-authored-by: Tian Xia <[email protected]>

* nit fixes

* format

* nit

* Implement method for reading acc counts from node/pod object

* assertion update for is_tpu_vm

* Exclude multi-host TPUs from being displayed in show-gpus

* Notify users in 'sky show-gpus' that multi-host TPUs are not supported

* format

* nit

* display warning message from show-gpus conditionally

* update sky show-gpus

* update get_accelerator_label_key_value

* format

* format

* format

* update comment

* resolve review comments

* update tpuvm_mnist.yaml

* resolve comments

* update display message for show-gpus

* format

---------

Co-authored-by: Tian Xia <[email protected]>
landscapepainter and cblmemo authored Nov 13, 2024
1 parent 2030398 commit eea13cc
Showing 12 changed files with 691 additions and 267 deletions.
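
For context, this change makes a single-host TPU slice on GKE requestable like any other Kubernetes accelerator. Below is a minimal sketch using SkyPilot's Python API; the accelerator name and chip count are illustrative placeholders, not values taken from this commit:

```python
import sky

# Hypothetical single-host TPU slice name and chip count; check
# `sky show-gpus --cloud kubernetes` for what your cluster actually exposes.
task = sky.Task(run='python -c "import jax; print(jax.devices())"')
task.set_resources(
    sky.Resources(cloud=sky.Kubernetes(),
                  accelerators={'tpu-v5-lite-podslice': 4}))

sky.launch(task, cluster_name='tpu-test')
```

The names actually available in a given cluster are whatever `sky show-gpus --cloud kubernetes` reports after this change; multi-host TPU slices remain unsupported and are excluded from that listing.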
4 changes: 2 additions & 2 deletions examples/tpu/tpuvm_mnist.yaml
@@ -5,7 +5,7 @@ resources:
 
 # The setup command. Will be run under the working directory.
 setup: |
-  git clone https://github.com/google/flax.git --branch v0.8.2
+  git clone https://github.com/google/flax.git --branch v0.10.1
   conda activate flax
   if [ $? -eq 0 ]; then
@@ -15,7 +15,7 @@ setup: |
     conda activate flax
     # Make sure to install TPU related packages in a conda env to avoid package conflicts.
     pip install \
-      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.25" \
+      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.35" \
       clu \
       tensorflow tensorflow-datasets
   pip install -e flax
24 changes: 18 additions & 6 deletions sky/cli.py
@@ -3143,7 +3143,8 @@ def _get_kubernetes_realtime_gpu_table(
                        'in Kubernetes cluster. ')
         debug_msg = ('To show available accelerators on kubernetes,'
                      ' run: sky show-gpus --cloud kubernetes ')
-        full_err_msg = (err_msg + kubernetes_utils.NO_GPU_HELP_MESSAGE +
+        full_err_msg = (err_msg +
+                        kubernetes_utils.NO_ACCELERATOR_HELP_MESSAGE +
                         debug_msg)
         raise ValueError(full_err_msg)
     for gpu, _ in sorted(counts.items()):
@@ -3161,11 +3162,12 @@ def _get_kubernetes_node_info_table(context: Optional[str]):
 
     node_info_dict = kubernetes_utils.get_kubernetes_node_info(context)
     for node_name, node_info in node_info_dict.items():
-        available = node_info.free['nvidia.com/gpu'] if node_info.free[
-            'nvidia.com/gpu'] != -1 else no_permissions_str
+        available = node_info.free[
+            'accelerators_available'] if node_info.free[
+                'accelerators_available'] != -1 else no_permissions_str
         node_table.add_row([
-            node_name, node_info.gpu_type,
-            node_info.total['nvidia.com/gpu'], available
+            node_name, node_info.accelerator_type,
+            node_info.total['accelerator_count'], available
         ])
     return node_table
 
@@ -3220,8 +3222,18 @@ def _output():
                 yield from k8s_realtime_table.get_string()
                 k8s_node_table = _get_kubernetes_node_info_table(context)
                 yield '\n\n'
+                # TODO(Doyoung): Update the message with the multi-host TPU
+                # support.
+                k8s_per_node_acc_message = (
+                    'Kubernetes per node accelerator availability ')
+                if kubernetes_utils.multi_host_tpu_exists_in_cluster(
+                        context):
+                    k8s_per_node_acc_message += (
+                        '(Note: Multi-host TPUs are detected and excluded '
+                        'from the display as multi-host TPUs are not '
+                        'supported.)')
                 yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
-                       f'Kubernetes per node GPU availability'
+                       f'{k8s_per_node_acc_message}'
                        f'{colorama.Style.RESET_ALL}\n')
                 yield from k8s_node_table.get_string()
                 if kubernetes_autoscaling:
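The `multi_host_tpu_exists_in_cluster` check above (like the `is_multi_host_tpu` filter used later in this diff) is a new helper in `kubernetes_utils` whose body is not shown in this excerpt. A plausible sketch of the per-node check, assuming GKE's `cloud.google.com/gke-tpu-topology` label and the per-host chip count from the node's `google.com/tpu` allocatable; the real implementation may differ:

```python
import math
from typing import Dict

# Assumed GKE node label; not defined in this excerpt.
GKE_TPU_TOPOLOGY_LABEL_KEY = 'cloud.google.com/gke-tpu-topology'


def is_multi_host_tpu_sketch(node_labels: Dict[str, str],
                             chips_per_host: int) -> bool:
    """Heuristic: a TPU slice is multi-host if its topology requires more
    chips than a single node advertises. Illustrative only."""
    topology = node_labels.get(GKE_TPU_TOPOLOGY_LABEL_KEY)
    if topology is None:
        return False
    # Topology strings look like '2x4' or '2x2x4'; their product is the
    # total chip count of the slice.
    total_chips = math.prod(int(dim) for dim in topology.split('x'))
    return total_chips > chips_per_host
```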
22 changes: 19 additions & 3 deletions sky/clouds/kubernetes.py
@@ -362,11 +362,23 @@ def make_deploy_resources_variables(
 
         k8s_acc_label_key = None
         k8s_acc_label_value = None
+        k8s_topology_label_key = None
+        k8s_topology_label_value = None
+        k8s_resource_key = None
+        tpu_requested = False
 
-        # If GPUs are requested, set node label to match the GPU type.
+        # If GPU/TPUs are requested, set node label to match the GPU/TPU type.
         if acc_count > 0 and acc_type is not None:
-            k8s_acc_label_key, k8s_acc_label_value = \
-                kubernetes_utils.get_gpu_label_key_value(context, acc_type)
+            (k8s_acc_label_key, k8s_acc_label_value, k8s_topology_label_key,
+             k8s_topology_label_value) = (
+                 kubernetes_utils.get_accelerator_label_key_value(
+                     context, acc_type, acc_count))
+            if (k8s_acc_label_key ==
+                    kubernetes_utils.GKELabelFormatter.TPU_LABEL_KEY):
+                tpu_requested = True
+                k8s_resource_key = kubernetes_utils.TPU_RESOURCE_KEY
+            else:
+                k8s_resource_key = kubernetes_utils.GPU_RESOURCE_KEY
 
         port_mode = network_utils.get_port_mode(None)
 
@@ -428,6 +440,10 @@ def make_deploy_resources_variables(
             'k8s_skypilot_system_namespace': _SKYPILOT_SYSTEM_NAMESPACE,
             'k8s_spot_label_key': spot_label_key,
             'k8s_spot_label_value': spot_label_value,
+            'tpu_requested': tpu_requested,
+            'k8s_topology_label_key': k8s_topology_label_key,
+            'k8s_topology_label_value': k8s_topology_label_value,
+            'k8s_resource_key': k8s_resource_key,
             'image_id': image_id,
         }
 
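The new deploy variables above are consumed by the Kubernetes pod template at provisioning time. A rough sketch of how they would typically materialize in a pod spec, following GKE's convention of scheduling TPU slices via accelerator and topology node selectors plus a `google.com/tpu` resource request; the helper and container name are illustrative, not the actual template:

```python
from typing import Any, Dict


def render_pod_spec_sketch(deploy_vars: Dict[str, Any],
                           acc_count: int) -> Dict[str, Any]:
    """Illustrative mapping of the new deploy variables into a pod spec."""
    node_selector = {
        deploy_vars['k8s_acc_label_key']: deploy_vars['k8s_acc_label_value'],
    }
    if deploy_vars['tpu_requested']:
        # GKE schedules TPU slices by accelerator type *and* topology, so
        # both node selectors are needed.
        node_selector[deploy_vars['k8s_topology_label_key']] = (
            deploy_vars['k8s_topology_label_value'])
    # 'google.com/tpu' for TPU pods, 'nvidia.com/gpu' for GPU pods.
    resource_key = deploy_vars['k8s_resource_key']
    return {
        'nodeSelector': node_selector,
        'containers': [{
            'name': 'sky-container',  # placeholder name
            'resources': {
                'requests': {resource_key: acc_count},
                'limits': {resource_key: acc_count},
            },
        }],
    }
```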
148 changes: 83 additions & 65 deletions sky/clouds/service_catalog/kubernetes_catalog.py
@@ -104,16 +104,16 @@ def list_accelerators_realtime(
     ) or not kubernetes_utils.check_credentials(context)[0]:
         return {}, {}, {}
 
-    has_gpu = kubernetes_utils.detect_gpu_resource(context)
+    has_gpu = kubernetes_utils.detect_accelerator_resource(context)
     if not has_gpu:
         return {}, {}, {}
 
-    label_formatter, _ = kubernetes_utils.detect_gpu_label_formatter(context)
-    if not label_formatter:
+    lf, _ = kubernetes_utils.detect_gpu_label_formatter(context)
+    if not lf:
         return {}, {}, {}
 
     accelerators_qtys: Set[Tuple[str, int]] = set()
-    key = label_formatter.get_label_key()
+    keys = lf.get_label_keys()
     nodes = kubernetes_utils.get_kubernetes_nodes(context)
     # Get the pods to get the real-time GPU usage
     try:
@@ -134,67 +134,85 @@
     min_quantity_filter = quantity_filter if quantity_filter else 1
 
     for node in nodes:
-        if key in node.metadata.labels:
-            allocated_qty = 0
-            accelerator_name = label_formatter.get_accelerator_from_label_value(
-                node.metadata.labels.get(key))
-
-            # Check if name_filter regex matches the accelerator_name
-            regex_flags = 0 if case_sensitive else re.IGNORECASE
-            if name_filter and not re.match(
-                    name_filter, accelerator_name, flags=regex_flags):
-                continue
-
-            accelerator_count = int(
-                node.status.allocatable.get('nvidia.com/gpu', 0))
-
-            # Generate the GPU quantities for the accelerators
-            if accelerator_name and accelerator_count > 0:
-                count = 1
-                while count <= accelerator_count:
-                    accelerators_qtys.add((accelerator_name, count))
-                    count *= 2
-                # Add the accelerator count if it's not already in the set
-                # (e.g., if there's 12 GPUs, we should have qtys 1, 2, 4, 8, 12)
-                if accelerator_count not in accelerators_qtys:
-                    accelerators_qtys.add((accelerator_name, accelerator_count))
-
-                if accelerator_count >= min_quantity_filter:
-                    quantized_count = (min_quantity_filter *
-                                       (accelerator_count // min_quantity_filter))
-                    if accelerator_name not in total_accelerators_capacity:
-                        total_accelerators_capacity[
-                            accelerator_name] = quantized_count
-                    else:
-                        total_accelerators_capacity[
-                            accelerator_name] += quantized_count
-
-                if pods is None:
-                    # If we can't get the pods, we can't get the GPU usage
-                    total_accelerators_available[accelerator_name] = -1
-                    continue
-
-                for pod in pods:
-                    # Get all the pods running on the node
-                    if (pod.spec.node_name == node.metadata.name and
-                            pod.status.phase in ['Running', 'Pending']):
-                        # Iterate over all the containers in the pod and sum the
-                        # GPU requests
-                        for container in pod.spec.containers:
-                            if container.resources.requests:
-                                allocated_qty += int(
-                                    container.resources.requests.get(
-                                        'nvidia.com/gpu', 0))
-
-                accelerators_available = accelerator_count - allocated_qty
-
-                if accelerator_name not in total_accelerators_available:
-                    total_accelerators_available[accelerator_name] = 0
-                if accelerators_available >= min_quantity_filter:
-                    quantized_availability = min_quantity_filter * (
-                        accelerators_available // min_quantity_filter)
-                    total_accelerators_available[
-                        accelerator_name] += quantized_availability
+        for key in keys:
+            if key in node.metadata.labels:
+                allocated_qty = 0
+                accelerator_name = lf.get_accelerator_from_label_value(
+                    node.metadata.labels.get(key))
+
+                # Exclude multi-host TPUs from being processed.
+                # TODO(Doyoung): Remove the logic when adding support for
+                # multi-host TPUs.
+                if kubernetes_utils.is_multi_host_tpu(node.metadata.labels):
+                    continue
+
+                # Check if name_filter regex matches the accelerator_name
+                regex_flags = 0 if case_sensitive else re.IGNORECASE
+                if name_filter and not re.match(
+                        name_filter, accelerator_name, flags=regex_flags):
+                    continue
+
+                # Generate the accelerator quantities
+                accelerator_count = (
+                    kubernetes_utils.get_node_accelerator_count(
+                        node.status.allocatable))
+
+                if accelerator_name and accelerator_count > 0:
+                    # TPUs are counted in a different way compared to GPUs.
+                    # Multi-node GPUs can be split into smaller units and be
+                    # provisioned, but TPUs are considered as an atomic unit.
+                    if kubernetes_utils.is_tpu_on_gke(accelerator_name):
+                        accelerators_qtys.add(
+                            (accelerator_name, accelerator_count))
+                    else:
+                        count = 1
+                        while count <= accelerator_count:
+                            accelerators_qtys.add((accelerator_name, count))
+                            count *= 2
+                        # Add the accelerator count if it's not already in the
+                        # set (e.g., if there's 12 GPUs, we should have qtys 1,
+                        # 2, 4, 8, 12)
+                        if accelerator_count not in accelerators_qtys:
+                            accelerators_qtys.add(
+                                (accelerator_name, accelerator_count))
+
+                    if accelerator_count >= min_quantity_filter:
+                        quantized_count = (
+                            min_quantity_filter *
+                            (accelerator_count // min_quantity_filter))
+                        if accelerator_name not in total_accelerators_capacity:
+                            total_accelerators_capacity[
+                                accelerator_name] = quantized_count
+                        else:
+                            total_accelerators_capacity[
+                                accelerator_name] += quantized_count
+
+                    if pods is None:
+                        # If we can't get the pods, we can't get the GPU usage
+                        total_accelerators_available[accelerator_name] = -1
+                        continue
+
+                    for pod in pods:
+                        # Get all the pods running on the node
+                        if (pod.spec.node_name == node.metadata.name and
+                                pod.status.phase in ['Running', 'Pending']):
+                            # Iterate over all the containers in the pod and sum
+                            # the GPU requests
+                            for container in pod.spec.containers:
+                                if container.resources.requests:
+                                    allocated_qty += (
+                                        kubernetes_utils.get_node_accelerator_count(
+                                            container.resources.requests))
+
+                    accelerators_available = accelerator_count - allocated_qty
+
+                    if accelerator_name not in total_accelerators_available:
+                        total_accelerators_available[accelerator_name] = 0
+                    if accelerators_available >= min_quantity_filter:
+                        quantized_availability = min_quantity_filter * (
+                            accelerators_available // min_quantity_filter)
+                        total_accelerators_available[
+                            accelerator_name] += quantized_availability
 
     result = []
 
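`get_node_accelerator_count` is applied above to both `node.status.allocatable` and each container's `resources.requests`. A minimal sketch, assuming a node or container advertises at most one of the two resource keys named in the commit message (`nvidia.com/gpu`, `google.com/tpu`):

```python
from typing import Dict

GPU_RESOURCE_KEY = 'nvidia.com/gpu'
TPU_RESOURCE_KEY = 'google.com/tpu'


def get_node_accelerator_count_sketch(resources: Dict[str, str]) -> int:
    """Return the GPU or TPU chip count advertised in a resource dict."""
    assert not (GPU_RESOURCE_KEY in resources and
                TPU_RESOURCE_KEY in resources), resources
    for key in (GPU_RESOURCE_KEY, TPU_RESOURCE_KEY):
        if key in resources:
            return int(resources[key])
    return 0
```

With that in place, the loop above enumerates power-of-two quantities (plus the total) for GPU nodes, but only the single atomic chip count for a GKE TPU slice node.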
6 changes: 5 additions & 1 deletion sky/clouds/utils/gcp_utils.py
@@ -17,6 +17,7 @@
 from sky import sky_logging
 from sky import skypilot_config
 from sky.provision.gcp import constants
+from sky.provision.kubernetes import utils as kubernetes_utils
 from sky.utils import subprocess_utils
 
 if typing.TYPE_CHECKING:
@@ -35,7 +36,10 @@ def is_tpu(resources: Optional['resources_lib.Resources']) -> bool:
 def is_tpu_vm(resources: Optional['resources_lib.Resources']) -> bool:
     if not is_tpu(resources):
         return False
-    assert resources is not None
+    assert (resources is not None and len(resources.accelerators) == 1)
+    acc, _ = list(resources.accelerators.items())[0]
+    if kubernetes_utils.is_tpu_on_gke(acc):
+        return False
     if resources.accelerator_args is None:
         return True
     return resources.accelerator_args.get('tpu_vm', True)
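`is_tpu_on_gke` lets `is_tpu_vm` return False for Kubernetes-backed TPUs, which are not GCP TPU VMs. A hedged sketch: the name set mirrors GKE's `cloud.google.com/gke-tpu-accelerator` label values but is illustrative, and the authoritative check lives in `sky/provision/kubernetes/utils.py`:

```python
# Hypothetical set of GKE TPU slice accelerator names; illustrative only.
_GKE_TPU_ACCELERATOR_NAMES = {
    'tpu-v4-podslice',
    'tpu-v5-lite-device',
    'tpu-v5-lite-podslice',
    'tpu-v5p-slice',
}


def is_tpu_on_gke_sketch(accelerator_name: str) -> bool:
    """True if the accelerator name looks like a GKE TPU slice."""
    return accelerator_name in _GKE_TPU_ACCELERATOR_NAMES
```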
(Diffs for the remaining changed files are not rendered in this view.)