Skip to content

Commit

Permalink
Azure: update fetch_azure to support two H100 families. (#2844)
Browse files Browse the repository at this point in the history
* Azure: update fetch_azure to support two H100 families.

* format
  • Loading branch information
concretevitamin authored Dec 7, 2023
1 parent 84313ed commit 66b8635
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 22 deletions.
50 changes: 30 additions & 20 deletions sky/clouds/service_catalog/data_fetchers/fetch_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,34 @@

SINGLE_THREADED = False

# Family name to SkyPilot GPU name mapping.
#
# When adding a new accelerator:
# - The instance type is typically already fetched, but we need to find the
# family name and add it to this mapping.
# - To inspect family names returned by Azure API, check the dataframes in
# get_all_regions_instance_types_df().
FAMILY_NAME_TO_SKYPILOT_GPU_NAME = {
'standardNCFamily': 'K80',
'standardNCSv2Family': 'P100',
'standardNCSv3Family': 'V100',
'standardNCPromoFamily': 'K80',
'StandardNCASv3_T4Family': 'T4',
'standardNDSv2Family': 'V100-32GB',
'StandardNCADSA100v4Family': 'A100-80GB',
'standardNDAMSv4_A100Family': 'A100-80GB',
'StandardNDASv4_A100Family': 'A100',
'standardNVFamily': 'M60',
'standardNVSv2Family': 'M60',
'standardNVSv3Family': 'M60',
'standardNVPromoFamily': 'M60',
'standardNVSv4Family': 'Radeon MI25',
'standardNDSFamily': 'P40',
'StandardNVADSA10v5Family': 'A10',
'StandardNCadsH100v5Family': 'H100',
'standardNDSH100v5Family': 'H100',
}


def get_regions() -> List[str]:
"""Get all available regions."""
Expand Down Expand Up @@ -78,7 +106,7 @@ def get_pricing_url(region: Optional[str] = None) -> str:
def get_pricing_df(region: Optional[str] = None) -> pd.DataFrame:
all_items = []
url = get_pricing_url(region)
print(f'Getting pricing for {region}')
print(f'Getting pricing for {region}, url: {url}')
page = 0
while url is not None:
page += 1
Expand Down Expand Up @@ -125,29 +153,11 @@ def get_sku_df(region_set: Set[str]) -> pd.DataFrame:


def get_gpu_name(family: str) -> Optional[str]:
gpu_data = {
'standardNCFamily': 'K80',
'standardNCSv2Family': 'P100',
'standardNCSv3Family': 'V100',
'standardNCPromoFamily': 'K80',
'StandardNCASv3_T4Family': 'T4',
'standardNDSv2Family': 'V100-32GB',
'StandardNCADSA100v4Family': 'A100-80GB',
'standardNDAMSv4_A100Family': 'A100-80GB',
'StandardNDASv4_A100Family': 'A100',
'standardNVFamily': 'M60',
'standardNVSv2Family': 'M60',
'standardNVSv3Family': 'M60',
'standardNVPromoFamily': 'M60',
'standardNVSv4Family': 'Radeon MI25',
'standardNDSFamily': 'P40',
'StandardNVADSA10v5Family': 'A10',
}
# NP-series offer Xilinx U250 FPGAs which are not GPUs,
# so we do not include them here.
# https://docs.microsoft.com/en-us/azure/virtual-machines/np-series
family = family.replace(' ', '')
return gpu_data.get(family)
return FAMILY_NAME_TO_SKYPILOT_GPU_NAME.get(family)


def get_all_regions_instance_types_df(region_set: Set[str]):
Expand Down
10 changes: 8 additions & 2 deletions sky/utils/accelerator_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
# NOTE: Must include accelerators supported for local clusters.
#
# 1. What if a name is in this list, but not in any catalog?
#
# The name will be canonicalized, but the accelerator will not be supported.
# Optimizer will print an error message.
#
# 2. What if a name is not in this list, but in a catalog?
#
# The list is simply an optimization to short-circuit the search in the catalog.
# If the name is not found in the list, it will be searched in the catalog
# with its case being ignored. If a match is found, the name will be
# canonicalized to that in the catalog. Note that this lookup can be an
# expensive operation, as it requires reading the catalog or making external
# API calls (such as for Kubernetes). Thus it is desirable to keep this list
# up-to-date with commonly used accelerators.

# 3. (For SkyPilot dev) What to do if I want to add a new accelerator?
#
# Append its case-sensitive canonical name to this list. The name must match
# `AcceleratorName` in the service catalog, or what we define in
# `onprem_utils.get_local_cluster_accelerators`.
Expand All @@ -42,6 +47,7 @@
'Radeon MI25',
'P4',
'L4',
'H100',
]


Expand Down Expand Up @@ -72,11 +78,11 @@ def canonicalize_accelerator_name(accelerator: str) -> str:
if len(names) == 1:
return names[0]

# Do not print an error meessage here. Optimizer will handle it.
# Do not print an error message here. Optimizer will handle it.
if len(names) == 0:
return accelerator

# Currenlty unreachable.
# Currently unreachable.
# This can happen if catalogs have the same accelerator with
# different names (e.g., A10g and A10G).
assert len(names) > 1
Expand Down

0 comments on commit 66b8635

Please sign in to comment.