Skip to content

Commit

Permalink
apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
cblmemo committed Jul 4, 2024
1 parent 4ccedce commit 018181f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
20 changes: 12 additions & 8 deletions sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import subprocess
import textwrap
import typing
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

import colorama

Expand Down Expand Up @@ -269,13 +269,12 @@ def get_vcpus_mem_from_instance_type(
def get_zone_shell_cmd(cls) -> Optional[str]:
return None

def make_deploy_resources_variables(
self,
resources: 'resources.Resources',
cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Optional[str]]:
def make_deploy_resources_variables(self,
resources: 'resources.Resources',
cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Any]:
assert zones is None, ('Azure does not support zones', zones)

region_name = region.name
Expand Down Expand Up @@ -315,6 +314,10 @@ def make_deploy_resources_variables(
'image_version': version,
}

# Setup the A10 nvidia driver.
need_nvidia_driver_extension = (resources.accelerators is not None and
'A10' in resources.accelerators)

# Setup commands to eliminate the banner and restart sshd.
# This script will modify /etc/ssh/sshd_config and add a bash script
# into .bashrc. The bash script will restart sshd if it has not been
Expand Down Expand Up @@ -367,6 +370,7 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]:
# Azure does not support specific zones.
'zones': None,
**image_config,
'need_nvidia_driver_extension': need_nvidia_driver_extension,
'disk_tier': Azure._get_disk_type(_failover_disk_tier()),
'cloud_init_setup_commands': cloud_init_setup_commands,
'azure_subscription_id': self.get_project_id(dryrun),
Expand Down
7 changes: 1 addition & 6 deletions sky/skylet/providers/azure/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,7 @@ def _create_node(self, node_config, tags, count):
template_params["nsg"] = self.provider_config["nsg"]
template_params["subnet"] = self.provider_config["subnet"]

# pylint: disable=import-outside-toplevel
from sky.clouds.service_catalog import azure_catalog

instance_type = node_config["azure_arm_parameters"].get("vmSize", "")
accs = azure_catalog.get_accelerators_from_instance_type(instance_type)
if accs is not None and "A10" in accs:
if node_config.get("need_nvidia_driver_extension", False):
# Configure driver extension for A10 GPUs. A10 GPUs requires a
# special type of drivers which is available at Microsoft HPC
# extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2
Expand Down
2 changes: 2 additions & 0 deletions sky/templates/azure-ray.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ available_node_types:
# billingProfile:
# maxPrice: -1
{%- endif %}
need_nvidia_driver_extension: {{need_nvidia_driver_extension}}
# TODO: attach disk
{% if num_nodes > 1 %}
ray.worker.default:
Expand Down Expand Up @@ -108,6 +109,7 @@ available_node_types:
# billingProfile:
# maxPrice: -1
{%- endif %}
need_nvidia_driver_extension: {{need_nvidia_driver_extension}}
{%- endif %}

head_node_type: ray.head.default
Expand Down

0 comments on commit 018181f

Please sign in to comment.