diff --git a/src/gpuhunt/_internal/catalog.py b/src/gpuhunt/_internal/catalog.py index 4f7fe52..2174652 100644 --- a/src/gpuhunt/_internal/catalog.py +++ b/src/gpuhunt/_internal/catalog.py @@ -3,11 +3,12 @@ import heapq import io import logging +import os import time import urllib.request import zipfile -from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor, wait +from pathlib import Path from typing import Optional, Union import gpuhunt._internal.constraints as constraints @@ -190,22 +191,28 @@ def _get_offline_provider_items( self, provider_name: str, query_filter: QueryFilter ) -> list[CatalogItem]: logger.debug("Loading items for offline provider %s", provider_name) - items = [] - - if self.catalog is None: - logger.warning("Catalog not loaded") - return items - - with zipfile.ZipFile(self.catalog) as zip_file: - with zip_file.open(f"{provider_name}.csv", "r") as csv_file: - reader: Iterable[dict[str, str]] = csv.DictReader( - io.TextIOWrapper(csv_file, "utf-8") - ) + items = []  # must exist before either branch below appends to or returns it + # Set this env var to use a local catalog instead of the s3 catalog + catalog_dir = os.getenv("GPUHUNT_CATALOG_DIR") + if catalog_dir is not None: + with open(Path(catalog_dir) / f"{provider_name}.csv", "rb") as csv_file: + reader = csv.DictReader(io.TextIOWrapper(csv_file, "utf-8")) for row in reader: item = CatalogItem.from_dict(row, provider=provider_name) if constraints.matches(item, query_filter): items.append(item) + else: + if self.catalog is None: + logger.warning("Catalog not loaded") + return items + with zipfile.ZipFile(self.catalog) as zip_file: + with zip_file.open(f"{provider_name}.csv", "r") as csv_file: + reader = csv.DictReader(io.TextIOWrapper(csv_file, "utf-8")) + for row in reader: + item = CatalogItem.from_dict(row, provider=provider_name) + if constraints.matches(item, query_filter): + items.append(item) return items def _get_online_provider_items( diff --git a/src/gpuhunt/_internal/constraints.py b/src/gpuhunt/_internal/constraints.py index 
2cbefef..7cad6de 100644 --- a/src/gpuhunt/_internal/constraints.py +++ b/src/gpuhunt/_internal/constraints.py @@ -11,7 +11,7 @@ ) # v5litepod = v5e -_TPU_VERSIONS = ["v2", "v3", "v4", "v5p", "v5litepod"] +_TPU_VERSIONS = ["v2", "v3", "v4", "v5p", "v5litepod", "v6e"] def _is_tpu(name: str) -> bool: diff --git a/src/gpuhunt/providers/gcp.py b/src/gpuhunt/providers/gcp.py index 52b4988..b2e7570 100644 --- a/src/gpuhunt/providers/gcp.py +++ b/src/gpuhunt/providers/gcp.py @@ -1,5 +1,5 @@ import copy -import importlib +import importlib.resources import json import logging import re @@ -13,7 +13,6 @@ from google.cloud.billing_v1 import CloudCatalogClient, ListSkusRequest from google.cloud.billing_v1.types.cloud_catalog import Sku from google.cloud.location import locations_pb2 -from google.cloud.location.locations_pb2 import ListLocationsResponse from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem from gpuhunt.providers import AbstractProvider @@ -90,15 +89,17 @@ } -def load_tpu_pricing(): - resource_package = "gpuhunt.resources" - resource_name = "tpu_pricing.json" - - with importlib.resources.open_text(resource_package, resource_name) as f: - return json.load(f) +def load_tpu_pricing() -> dict: + return json.loads( + importlib.resources.files("gpuhunt.resources").joinpath("tpu_pricing.json").read_text() + ) -tpu_pricing: dict = load_tpu_pricing() +# A manually filled TPU pricing table from the pricing page: +# https://cloud.google.com/tpu/pricing?hl=en. +# It's needed since the TPU pricing API does not return prices for all regions. +# The API may also return 1-year Commitment prices instead of on-demand prices. 
+TPU_PRICING_TABLE = load_tpu_pricing() class GCPProvider(AbstractProvider): @@ -218,7 +219,6 @@ def get( self.add_gpus(instances) offers = self.fill_prices(instances) self.fill_gpu_vendors_and_names(offers) - # Add tpu offerings offers.extend(get_tpu_offers(self.project)) return sorted(offers, key=lambda i: i.price) @@ -354,9 +354,10 @@ def get_vm_family(instance_name: str) -> str: def get_tpu_offers(project_id: str) -> list[RawCatalogItem]: - logger.info("Fetching tpu offers") + logger.info("Fetching TPU offers") raw_catalog_items: list[RawCatalogItem] = [] catalog_items: list[dict] = get_catalog_items(project_id) + # For some TPU offers in some regions, GCP does not list prices at all. Skip such offers. filtered_catalog_items = [item for item in catalog_items if item["price"] is not None] for item in filtered_catalog_items: on_demand_item = RawCatalogItem( @@ -382,6 +383,15 @@ def get_tpu_offers(project_id: str) -> list[RawCatalogItem]: def get_catalog_items(project_id: str) -> list[dict]: + """ + Returns TPU configurations with pricing info. + Each configuration contains on-demand price and spot price but any price can be missing. + This is because the API does not return prices for all regions. + As a backup, the prices are taken from the pricing table on the GCP website, + but it also does not contain all the prices. + Even when creating some TPUs in some regions via the GCP console, + the price is not shown (e.g. v6e in us-south1). + """ tpu_prices: list[dict] = get_tpu_prices() configs: list[dict] = get_tpu_configs(project_id) for config in configs: @@ -390,20 +400,19 @@ def get_catalog_items(project_id: str) -> list[dict]: no_of_chips = config["no_of_chips"] tpu_version, no_of_cores = instance_name.rsplit("-", 1) no_of_cores = int(no_of_cores) - if tpu_version in ["v5litepod", "v5p"]: - # For TPU-v5 series, api provides per chip price. 
- # Verify per chip price in the following link.https://cloud.google.com/tpu/pricing + if tpu_version in ["v5litepod", "v5p", "v6e"]: + # For TPU-v5 series and v6e, the API provides per chip price. is_pod = True on_demand_base_price = find_base_price_v5( tpu_version, location, tpu_prices, spot=False ) - spot_base_price = find_base_price_v5(tpu_version, location, tpu_prices, spot=True) if on_demand_base_price is not None: on_demand_price = on_demand_base_price * no_of_chips else: on_demand_price = find_tpu_price_static_src( tpu_version, no_of_cores, location, no_of_chips, False ) + spot_base_price = find_base_price_v5(tpu_version, location, tpu_prices, spot=True) if spot_base_price is not None: spot_price = spot_base_price * no_of_chips else: @@ -412,7 +421,7 @@ def get_catalog_items(project_id: str) -> list[dict]: ) elif tpu_version in ["v2", "v3", "v4"]: # For TPU-v2 and TPU-v3, the pricing API provides the prices of 8 TPU cores. - # For TPU-v4, api only provides the price of TPU-v4 pods. + # For TPU-v4, the API provides the price of TPU-v4 pods. 
if no_of_cores > 8 or tpu_version == "v4": is_pod = True base_instance_name = f"{tpu_version}-8" @@ -420,37 +429,39 @@ def get_catalog_items(project_id: str) -> list[dict]: on_demand_base_price = find_base_price( tpu_version, location, tpu_prices, spot=False, is_pod=True ) - spot_base_price = find_base_price( - tpu_version, location, tpu_prices, spot=True, is_pod=True - ) - if on_demand_base_price is not None and base_no_of_chips is not None: on_demand_price = (on_demand_base_price / base_no_of_chips) * no_of_chips else: on_demand_price = find_tpu_price_static_src( tpu_version, no_of_cores, location, no_of_chips, False ) + spot_base_price = find_base_price( + tpu_version, location, tpu_prices, spot=True, is_pod=True + ) if spot_base_price is not None and base_no_of_chips is not None: spot_price = (spot_base_price / base_no_of_chips) * no_of_chips else: spot_price = find_tpu_price_static_src( tpu_version, no_of_cores, location, no_of_chips, True ) - elif no_of_cores == 8: is_pod = False + base_no_of_chips = no_of_chips on_demand_base_price = find_base_price( tpu_version, location, tpu_prices, spot=False, is_pod=False ) + on_demand_price = on_demand_base_price spot_base_price = find_base_price( tpu_version, location, tpu_prices, spot=True, is_pod=False ) - on_demand_price = ( - on_demand_base_price if on_demand_base_price is not None else None - ) - spot_price = spot_base_price if spot_base_price is not None else None - base_no_of_chips = no_of_chips - + spot_price = spot_base_price + else: + logger.warning("Unknown TPU version %s. 
Skipping offer.", tpu_version) + continue + if on_demand_price is None: + logger.debug("Failed to find on-demand price for %s in %s", instance_name, location) + if spot_price is None: + logger.debug("Failed to find spot price for %s in %s", instance_name, location) config["price"] = on_demand_price config["spot"] = spot_price config["is_pod"] = is_pod @@ -464,23 +475,16 @@ def get_tpu_prices() -> list[dict]: client = CloudCatalogClient() tpu_configs = [] # E000-3F24-B8AA contains prices for TPU versions v2,v3,v4. - # 6F81-5844-456A contains prices for TPU versions v5p and v5litepod(v5e) + # 6F81-5844-456A contains prices for newer TPU versions v5p, v5litepod(v5e), v6e. service_names = ["services/E000-3F24-B8AA", "services/6F81-5844-456A"] - - # Loop through each service name and list SKUs for service_name in service_names: - # Create the request request = ListSkusRequest(parent=service_name) - - # List SKUs response = client.list_skus(request=request) - for sku in response.skus: if sku.category.resource_group != "TPU": continue if sku.category.usage_type not in ["OnDemand", "Preemptible"]: continue - tpu_version = extract_tpu_version(sku.description) if tpu_version: is_pod = True if "Pod" in sku.description else False @@ -524,21 +528,23 @@ def find_no_of_chips(instance_name: str, configs: list[dict]): def find_tpu_price_static_src( tpu_version: str, num_cores: int, tpu_region: str, no_of_chips: int, spot: bool ) -> Optional[float]: - # Pricing table names v5litepod as v5e + # The pricing page names v5litepod as v5e tpu_version = "v5e" if tpu_version == "v5litepod" else tpu_version - is_pod = num_cores > 8 or tpu_version == "v4" - tpu_type = f"TPU {tpu_version} pod" if is_pod else f"TPU {tpu_version} device" - # Comment here - if tpu_version == "v5p" or tpu_version == "v5e": - tpu_type = f"TPU {tpu_version}" + # The pricing page lists different (device and pod) prices per chip for v2 and v3. + # The device is the smallest configuration with 8 TPU cores (e.g. v3-8). 
+ # A TPU Pod connects multiple TPU devices (e.g. v3-32). + # Not applicable for newer generations. + tpu_type = f"TPU {tpu_version}" + if tpu_version in ["v2", "v3", "v4"]: + is_pod = num_cores > 8 or tpu_version == "v4" + tpu_type = f"TPU {tpu_version} pod" if is_pod else f"TPU {tpu_version} device" + price_key = "On Demand (USD)" + if spot: + price_key = "Spot (USD)" try: - on_demand_price = tpu_pricing[tpu_type][tpu_region]["On Demand (USD)"] * no_of_chips - spot_price = tpu_pricing[tpu_type][tpu_region]["Spot (USD)"] * no_of_chips - return on_demand_price if not spot else spot_price + return TPU_PRICING_TABLE[tpu_type][tpu_region][price_key] * no_of_chips except KeyError: - logger.warning( - f'key error for {tpu_type} {tpu_region} {"On Demand (USD)" if spot else "Spot (USD)"}' - ) + logger.debug(f"KeyError for {tpu_type} {tpu_region} {price_key}") return None @@ -563,7 +569,6 @@ def get_tpu_configs(project_id: str) -> list[dict]: request = tpu_v2.ListAcceleratorTypesRequest( parent=parent, ) - # request = tpu_v1.ListAcceleratorTypesRequest(parent=parent) page_result = client.list_accelerator_types(request=request) for response in page_result: no_of_chips = get_no_of_chips(response.accelerator_configs[0].topology) @@ -579,11 +584,8 @@ def get_tpu_configs(project_id: str) -> list[dict]: def get_no_of_chips(expression: str) -> int: - # Split the expression by 'x' factors = expression.split("x") - # Convert each factor to an integer factors = map(int, factors) - # Calculate the product by multiplying all factors product = 1 for factor in factors: product *= factor @@ -592,11 +594,8 @@ def get_no_of_chips(expression: str) -> int: def get_locations(project_id: str) -> list[str]: client = tpu_v2.TpuClient() - # Initialize request argument(s) parent = f"projects/{project_id}" - list_locations_request: ListLocationsResponse = client.list_locations( - locations_pb2.ListLocationsRequest(name=parent) - ) + list_locations_request = 
client.list_locations(locations_pb2.ListLocationsRequest(name=parent)) locations = [loc.location_id for loc in list_locations_request.locations] # TPU V4 only available in location us-central2-b only. # us-central2-b needs to be enabled in the project. @@ -606,7 +605,6 @@ def get_locations(project_id: str) -> list[str]: def extract_tpu_version(input_string: str) -> Optional[str]: # The regular expression pattern to find a substring starting with 'Tpu' pattern = r"\bTpu[-\w]*\b" - # Search for the first match of the pattern match = re.search(pattern, input_string, re.IGNORECASE) if match: @@ -618,5 +616,4 @@ def extract_tpu_version(input_string: str) -> Optional[str]: # Name of v5e in gcp console is v5litepod version = "v5litepod" if version_match.group() == "v5e" else version_match.group() return version - return None diff --git a/src/gpuhunt/resources/tpu_pricing.json b/src/gpuhunt/resources/tpu_pricing.json index 09510c1..d70e3cf 100644 --- a/src/gpuhunt/resources/tpu_pricing.json +++ b/src/gpuhunt/resources/tpu_pricing.json @@ -1,4 +1,30 @@ { + "TPU v6e": { + "us-east1": { + "Location": "South Carolina", + "On Demand (USD)": 2.7000, + "1-year Commitment (USD)": 1.8900, + "3-year Commitment (USD)": 1.2200 + }, + "us-east5": { + "Location": "Ohio", + "On Demand (USD)": 2.7000, + "1-year Commitment (USD)": 1.8900, + "3-year Commitment (USD)": 1.2200 + }, + "europe-west4": { + "Location": "Netherlands", + "On Demand (USD)": 2.9700, + "1-year Commitment (USD)": 2.0800, + "3-year Commitment (USD)": 1.3400 + }, + "asia-northeast1": { + "Location": "Tokyo", + "On Demand (USD)": 3.2400, + "1-year Commitment (USD)": 2.2700, + "3-year Commitment (USD)": 1.4600 + } + }, "TPU v5p": { "us-east5": { "Location": "Ohio",