Skip to content

Commit

Permalink
[Core] Fix A10 GPU on Azure (#3707)
Browse files Browse the repository at this point in the history
* init

* works. todo: only do this for A10 VMs

* only install for A10 instances

* merge into one template

* Update sky/skylet/providers/azure/node_provider.py

Co-authored-by: Zhanghao Wu <[email protected]>

* add warning

* apply suggestions from code review

* Update sky/clouds/azure.py

Co-authored-by: Zhanghao Wu <[email protected]>

---------

Co-authored-by: Zhanghao Wu <[email protected]>
  • Loading branch information
cblmemo and Michaelvll authored Jul 5, 2024
1 parent 05ce5e9 commit 994d35a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 8 deletions.
8 changes: 8 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,8 +2020,16 @@ def provision_with_retries(
failover_history: List[Exception] = list()

style = colorama.Style
fore = colorama.Fore
# Retrying launchable resources.
while True:
if (isinstance(to_provision.cloud, clouds.Azure) and
to_provision.accelerators is not None and
'A10' in to_provision.accelerators):
logger.warning(f'{style.BRIGHT}{fore.YELLOW}Trying to launch '
'an A10 cluster on Azure. This may take ~20 '
'minutes due to driver installation.'
f'{style.RESET_ALL}')
try:
# Recheck cluster name as the 'except:' block below may
# change the cloud assignment.
Expand Down
20 changes: 12 additions & 8 deletions sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import subprocess
import textwrap
import typing
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

import colorama

Expand Down Expand Up @@ -269,13 +269,12 @@ def get_vcpus_mem_from_instance_type(
def get_zone_shell_cmd(cls) -> Optional[str]:
return None

def make_deploy_resources_variables(
self,
resources: 'resources.Resources',
cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Optional[str]]:
def make_deploy_resources_variables(self,
resources: 'resources.Resources',
cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Any]:
assert zones is None, ('Azure does not support zones', zones)

region_name = region.name
Expand Down Expand Up @@ -315,6 +314,10 @@ def make_deploy_resources_variables(
'image_version': version,
}

# Setup the A10 nvidia driver.
need_nvidia_driver_extension = (acc_dict is not None and
'A10' in acc_dict)

# Setup commands to eliminate the banner and restart sshd.
# This script will modify /etc/ssh/sshd_config and add a bash script
# into .bashrc. The bash script will restart sshd if it has not been
Expand Down Expand Up @@ -367,6 +370,7 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]:
# Azure does not support specific zones.
'zones': None,
**image_config,
'need_nvidia_driver_extension': need_nvidia_driver_extension,
'disk_tier': Azure._get_disk_type(_failover_disk_tier()),
'cloud_init_setup_commands': cloud_init_setup_commands,
'azure_subscription_id': self.get_project_id(dryrun),
Expand Down
27 changes: 27 additions & 0 deletions sky/skylet/providers/azure/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,33 @@ def _create_node(self, node_config, tags, count):
template_params["nsg"] = self.provider_config["nsg"]
template_params["subnet"] = self.provider_config["subnet"]

if node_config.get("need_nvidia_driver_extension", False):
# Configure driver extension for A10 GPUs. A10 GPUs requires a
# special type of drivers which is available at Microsoft HPC
# extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2
for r in template["resources"]:
if r["type"] == "Microsoft.Compute/virtualMachines":
# Add a nested extension resource for A10 GPUs
r["resources"] = [
{
"type": "extensions",
"apiVersion": "2015-06-15",
"location": "[variables('location')]",
"dependsOn": [
"[concat('Microsoft.Compute/virtualMachines/', parameters('vmName'), copyIndex())]"
],
"name": "NvidiaGpuDriverLinux",
"properties": {
"publisher": "Microsoft.HpcCompute",
"type": "NvidiaGpuDriverLinux",
"typeHandlerVersion": "1.9",
"autoUpgradeMinorVersion": True,
"settings": {},
},
},
]
break

parameters = {
"properties": {
"mode": DeploymentMode.incremental,
Expand Down
2 changes: 2 additions & 0 deletions sky/templates/azure-ray.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ available_node_types:
# billingProfile:
# maxPrice: -1
{%- endif %}
need_nvidia_driver_extension: {{need_nvidia_driver_extension}}
# TODO: attach disk
{% if num_nodes > 1 %}
ray.worker.default:
Expand Down Expand Up @@ -108,6 +109,7 @@ available_node_types:
# billingProfile:
# maxPrice: -1
{%- endif %}
need_nvidia_driver_extension: {{need_nvidia_driver_extension}}
{%- endif %}

head_node_type: ray.head.default
Expand Down

0 comments on commit 994d35a

Please sign in to comment.