diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index b75f9207856..19fca673977 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -7,7 +7,7 @@ import subprocess import textwrap import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import colorama @@ -269,13 +269,12 @@ def get_vcpus_mem_from_instance_type( def get_zone_shell_cmd(cls) -> Optional[str]: return None - def make_deploy_resources_variables( - self, - resources: 'resources.Resources', - cluster_name_on_cloud: str, - region: 'clouds.Region', - zones: Optional[List['clouds.Zone']], - dryrun: bool = False) -> Dict[str, Optional[str]]: + def make_deploy_resources_variables(self, + resources: 'resources.Resources', + cluster_name_on_cloud: str, + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], + dryrun: bool = False) -> Dict[str, Any]: assert zones is None, ('Azure does not support zones', zones) region_name = region.name @@ -315,6 +314,10 @@ def make_deploy_resources_variables( 'image_version': version, } + # Flag whether the A10 NVIDIA driver extension is needed. + need_nvidia_driver_extension = (resources.accelerators is not None and + 'A10' in resources.accelerators) + # Setup commands to eliminate the banner and restart sshd. # This script will modify /etc/ssh/sshd_config and add a bash script # into .bashrc. The bash script will restart sshd if it has not been @@ -367,6 +370,7 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]: # Azure does not support specific zones. 
'zones': None, **image_config, + 'need_nvidia_driver_extension': need_nvidia_driver_extension, 'disk_tier': Azure._get_disk_type(_failover_disk_tier()), 'cloud_init_setup_commands': cloud_init_setup_commands, 'azure_subscription_id': self.get_project_id(dryrun), diff --git a/sky/skylet/providers/azure/node_provider.py b/sky/skylet/providers/azure/node_provider.py index 03b2b3beb6b..5f87e57245e 100644 --- a/sky/skylet/providers/azure/node_provider.py +++ b/sky/skylet/providers/azure/node_provider.py @@ -303,12 +303,7 @@ def _create_node(self, node_config, tags, count): template_params["nsg"] = self.provider_config["nsg"] template_params["subnet"] = self.provider_config["subnet"] - # pylint: disable=import-outside-toplevel - from sky.clouds.service_catalog import azure_catalog - - instance_type = node_config["azure_arm_parameters"].get("vmSize", "") - accs = azure_catalog.get_accelerators_from_instance_type(instance_type) - if accs is not None and "A10" in accs: + if node_config.get("need_nvidia_driver_extension", False): # Configure driver extension for A10 GPUs. A10 GPUs requires a # special type of drivers which is available at Microsoft HPC # extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2 diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 66eac439453..e8c388e1879 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -80,6 +80,7 @@ available_node_types: # billingProfile: # maxPrice: -1 {%- endif %} + need_nvidia_driver_extension: {{need_nvidia_driver_extension}} # TODO: attach disk {% if num_nodes > 1 %} ray.worker.default: @@ -108,6 +109,7 @@ available_node_types: # billingProfile: # maxPrice: -1 {%- endif %} + need_nvidia_driver_extension: {{need_nvidia_driver_extension}} {%- endif %} head_node_type: ray.head.default