Skip to content

Commit

Permalink
Support TPU v6e (#111)
Browse files Browse the repository at this point in the history
* Support TPU v6e

* Introduce GPUHUNT_CATALOG_DIR env var for development
  • Loading branch information
r4victor authored Dec 20, 2024
1 parent cb61539 commit edc13c0
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 70 deletions.
30 changes: 18 additions & 12 deletions src/gpuhunt/_internal/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import heapq
import io
import logging
import os
import time
import urllib.request
import zipfile
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor, wait
from pathlib import Path
from typing import Optional, Union

import gpuhunt._internal.constraints as constraints
Expand Down Expand Up @@ -190,22 +191,27 @@ def _get_offline_provider_items(
self, provider_name: str, query_filter: QueryFilter
) -> list[CatalogItem]:
logger.debug("Loading items for offline provider %s", provider_name)

items = []

if self.catalog is None:
logger.warning("Catalog not loaded")
return items

with zipfile.ZipFile(self.catalog) as zip_file:
with zip_file.open(f"{provider_name}.csv", "r") as csv_file:
reader: Iterable[dict[str, str]] = csv.DictReader(
io.TextIOWrapper(csv_file, "utf-8")
)
# Set this env var to use a local catalog instead of the s3 catalog
catalog_dir = os.getenv("GPUHUNT_CATALOG_DIR")
if catalog_dir is not None:
with open(Path(catalog_dir) / f"{provider_name}.csv", "rb") as csv_file:
reader = csv.DictReader(io.TextIOWrapper(csv_file, "utf-8"))
for row in reader:
item = CatalogItem.from_dict(row, provider=provider_name)
if constraints.matches(item, query_filter):
items.append(item)
else:
if self.catalog is None:
logger.warning("Catalog not loaded")
return items
with zipfile.ZipFile(self.catalog) as zip_file:
with zip_file.open(f"{provider_name}.csv", "r") as csv_file:
reader = csv.DictReader(io.TextIOWrapper(csv_file, "utf-8"))
for row in reader:
item = CatalogItem.from_dict(row, provider=provider_name)
if constraints.matches(item, query_filter):
items.append(item)
return items

def _get_online_provider_items(
Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/_internal/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)

# v5litepod = v5e
_TPU_VERSIONS = ["v2", "v3", "v4", "v5p", "v5litepod"]
_TPU_VERSIONS = ["v2", "v3", "v4", "v5p", "v5litepod", "v6e"]


def _is_tpu(name: str) -> bool:
Expand Down
111 changes: 54 additions & 57 deletions src/gpuhunt/providers/gcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
import importlib
import importlib.resources
import json
import logging
import re
Expand All @@ -13,7 +13,6 @@
from google.cloud.billing_v1 import CloudCatalogClient, ListSkusRequest
from google.cloud.billing_v1.types.cloud_catalog import Sku
from google.cloud.location import locations_pb2
from google.cloud.location.locations_pb2 import ListLocationsResponse

from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem
from gpuhunt.providers import AbstractProvider
Expand Down Expand Up @@ -90,15 +89,17 @@
}


def load_tpu_pricing():
resource_package = "gpuhunt.resources"
resource_name = "tpu_pricing.json"

with importlib.resources.open_text(resource_package, resource_name) as f:
return json.load(f)
def load_tpu_pricing() -> dict:
return json.loads(
importlib.resources.files("gpuhunt.resources").joinpath("tpu_pricing.json").read_text()
)


tpu_pricing: dict = load_tpu_pricing()
# A manually filled TPU pricing table from the pricing page:
# https://cloud.google.com/tpu/pricing?hl=en.
# It's needed since the TPU pricing API does not return prices for all regions.
# The API may also return 1-year Commitment prices instead of on-demand prices.
TPU_PRICING_TABLE = load_tpu_pricing()


class GCPProvider(AbstractProvider):
Expand Down Expand Up @@ -218,7 +219,6 @@ def get(
self.add_gpus(instances)
offers = self.fill_prices(instances)
self.fill_gpu_vendors_and_names(offers)
# Add tpu offerings
offers.extend(get_tpu_offers(self.project))
return sorted(offers, key=lambda i: i.price)

Expand Down Expand Up @@ -354,9 +354,10 @@ def get_vm_family(instance_name: str) -> str:


def get_tpu_offers(project_id: str) -> list[RawCatalogItem]:
logger.info("Fetching tpu offers")
logger.info("Fetching TPU offers")
raw_catalog_items: list[RawCatalogItem] = []
catalog_items: list[dict] = get_catalog_items(project_id)
# For some TPU offers in some regions, GCP does not list prices at all. Skip such offers.
filtered_catalog_items = [item for item in catalog_items if item["price"] is not None]
for item in filtered_catalog_items:
on_demand_item = RawCatalogItem(
Expand All @@ -382,6 +383,15 @@ def get_tpu_offers(project_id: str) -> list[RawCatalogItem]:


def get_catalog_items(project_id: str) -> list[dict]:
"""
Returns TPU configurations with pricing info.
Each configuration contains on-demand price and spot price but any price can be missing.
This is because the API does not return prices for all regions.
As a backup, the prices are taken from the pricing table on the GCP website,
but it also does not contain all the prices.
Even when creating some TPUs in some regions via the GCP console,
the price is not shown (e.g. v6e in us-south1).
"""
tpu_prices: list[dict] = get_tpu_prices()
configs: list[dict] = get_tpu_configs(project_id)
for config in configs:
Expand All @@ -390,20 +400,19 @@ def get_catalog_items(project_id: str) -> list[dict]:
no_of_chips = config["no_of_chips"]
tpu_version, no_of_cores = instance_name.rsplit("-", 1)
no_of_cores = int(no_of_cores)
if tpu_version in ["v5litepod", "v5p"]:
# For TPU-v5 series, api provides per chip price.
# Verify per chip price in the following link.https://cloud.google.com/tpu/pricing
if tpu_version in ["v5litepod", "v5p", "v6e"]:
# For TPU-v5 series, the API provides per chip price.
is_pod = True
on_demand_base_price = find_base_price_v5(
tpu_version, location, tpu_prices, spot=False
)
spot_base_price = find_base_price_v5(tpu_version, location, tpu_prices, spot=True)
if on_demand_base_price is not None:
on_demand_price = on_demand_base_price * no_of_chips
else:
on_demand_price = find_tpu_price_static_src(
tpu_version, no_of_cores, location, no_of_chips, False
)
spot_base_price = find_base_price_v5(tpu_version, location, tpu_prices, spot=True)
if spot_base_price is not None:
spot_price = spot_base_price * no_of_chips
else:
Expand All @@ -412,45 +421,47 @@ def get_catalog_items(project_id: str) -> list[dict]:
)
elif tpu_version in ["v2", "v3", "v4"]:
# For TPU-v2 and TPU-v3, the pricing API provides the prices of 8 TPU cores.
# For TPU-v4, api only provides the price of TPU-v4 pods.
# For TPU-v4, the API provides the price of TPU-v4 pods.
if no_of_cores > 8 or tpu_version == "v4":
is_pod = True
base_instance_name = f"{tpu_version}-8"
base_no_of_chips = find_no_of_chips(base_instance_name, configs)
on_demand_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=False, is_pod=True
)
spot_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=True, is_pod=True
)

if on_demand_base_price is not None and base_no_of_chips is not None:
on_demand_price = (on_demand_base_price / base_no_of_chips) * no_of_chips
else:
on_demand_price = find_tpu_price_static_src(
tpu_version, no_of_cores, location, no_of_chips, False
)
spot_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=True, is_pod=True
)
if spot_base_price is not None and base_no_of_chips is not None:
spot_price = (spot_base_price / base_no_of_chips) * no_of_chips
else:
spot_price = find_tpu_price_static_src(
tpu_version, no_of_cores, location, no_of_chips, True
)

elif no_of_cores == 8:
is_pod = False
base_no_of_chips = no_of_chips
on_demand_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=False, is_pod=False
)
on_demand_price = on_demand_base_price
spot_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=True, is_pod=False
)
on_demand_price = (
on_demand_base_price if on_demand_base_price is not None else None
)
spot_price = spot_base_price if spot_base_price is not None else None
base_no_of_chips = no_of_chips

spot_price = spot_base_price
else:
logger.warning("Unknown TPU version %s. Skipping offer.", tpu_version)
continue
if on_demand_price is None:
logger.debug("Failed to find on-demand price for %s in %s", instance_name, location)
if spot_price is None:
logger.debug("Failed to find spot price for %s in %s", instance_name, location)
config["price"] = on_demand_price
config["spot"] = spot_price
config["is_pod"] = is_pod
Expand All @@ -464,23 +475,16 @@ def get_tpu_prices() -> list[dict]:
client = CloudCatalogClient()
tpu_configs = []
# E000-3F24-B8AA contains prices for TPU versions v2,v3,v4.
# 6F81-5844-456A contains prices for TPU versions v5p and v5litepod(v5e)
# 6F81-5844-456A contains prices for newer TPU versions v5p, v5litepod(v5e), v6e.
service_names = ["services/E000-3F24-B8AA", "services/6F81-5844-456A"]

# Loop through each service name and list SKUs
for service_name in service_names:
# Create the request
request = ListSkusRequest(parent=service_name)

# List SKUs
response = client.list_skus(request=request)

for sku in response.skus:
if sku.category.resource_group != "TPU":
continue
if sku.category.usage_type not in ["OnDemand", "Preemptible"]:
continue

tpu_version = extract_tpu_version(sku.description)
if tpu_version:
is_pod = True if "Pod" in sku.description else False
Expand Down Expand Up @@ -524,21 +528,23 @@ def find_no_of_chips(instance_name: str, configs: list[dict]):
def find_tpu_price_static_src(
tpu_version: str, num_cores: int, tpu_region: str, no_of_chips: int, spot: bool
) -> Optional[float]:
# Pricing table names v5litepod as v5e
# The pricing page names v5litepod as v5e
tpu_version = "v5e" if tpu_version == "v5litepod" else tpu_version
is_pod = num_cores > 8 or tpu_version == "v4"
tpu_type = f"TPU {tpu_version} pod" if is_pod else f"TPU {tpu_version} device"
# Comment here
if tpu_version == "v5p" or tpu_version == "v5e":
tpu_type = f"TPU {tpu_version}"
# The pricing page lists different (device and pod) prices per chip for v2 and v3.
# The device is the smallest configuration with 8 TPU cores (e.g. v3-8).
# TPU Pod connects mulitple TPU devices (e.g. v3-32).
# Not applicable for newer generations.
tpu_type = f"TPU {tpu_version}"
if tpu_version in ["v2", "v3", "v4"]:
is_pod = num_cores > 8 or tpu_version == "v4"
tpu_type = f"TPU {tpu_version} pod" if is_pod else f"TPU {tpu_version} device"
price_key = "On Demand (USD)"
if spot:
price_key = "Spot (USD)"
try:
on_demand_price = tpu_pricing[tpu_type][tpu_region]["On Demand (USD)"] * no_of_chips
spot_price = tpu_pricing[tpu_type][tpu_region]["Spot (USD)"] * no_of_chips
return on_demand_price if not spot else spot_price
return TPU_PRICING_TABLE[tpu_type][tpu_region][price_key] * no_of_chips
except KeyError:
logger.warning(
f'key error for {tpu_type} {tpu_region} {"On Demand (USD)" if spot else "Spot (USD)"}'
)
logger.debug(f"KeyError for {tpu_type} {tpu_region} {price_key}")
return None


Expand All @@ -563,7 +569,6 @@ def get_tpu_configs(project_id: str) -> list[dict]:
request = tpu_v2.ListAcceleratorTypesRequest(
parent=parent,
)
# request = tpu_v1.ListAcceleratorTypesRequest(parent=parent)
page_result = client.list_accelerator_types(request=request)
for response in page_result:
no_of_chips = get_no_of_chips(response.accelerator_configs[0].topology)
Expand All @@ -579,11 +584,8 @@ def get_tpu_configs(project_id: str) -> list[dict]:


def get_no_of_chips(expression: str) -> int:
# Split the expression by 'x'
factors = expression.split("x")
# Convert each factor to an integer
factors = map(int, factors)
# Calculate the product by multiplying all factors
product = 1
for factor in factors:
product *= factor
Expand All @@ -592,11 +594,8 @@ def get_no_of_chips(expression: str) -> int:

def get_locations(project_id: str) -> list[str]:
client = tpu_v2.TpuClient()
# Initialize request argument(s)
parent = f"projects/{project_id}"
list_locations_request: ListLocationsResponse = client.list_locations(
locations_pb2.ListLocationsRequest(name=parent)
)
list_locations_request = client.list_locations(locations_pb2.ListLocationsRequest(name=parent))
locations = [loc.location_id for loc in list_locations_request.locations]
# TPU V4 only available in location us-central2-b only.
# us-central2-b needs to be enabled in the project.
Expand All @@ -606,7 +605,6 @@ def get_locations(project_id: str) -> list[str]:
def extract_tpu_version(input_string: str) -> Optional[str]:
# The regular expression pattern to find a substring starting with 'Tpu'
pattern = r"\bTpu[-\w]*\b"

# Search for the first match of the pattern
match = re.search(pattern, input_string, re.IGNORECASE)
if match:
Expand All @@ -618,5 +616,4 @@ def extract_tpu_version(input_string: str) -> Optional[str]:
# Name of v5e in gcp console is v5litepod
version = "v5litepod" if version_match.group() == "v5e" else version_match.group()
return version

return None
26 changes: 26 additions & 0 deletions src/gpuhunt/resources/tpu_pricing.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,30 @@
{
"TPU v6e": {
"us-east1": {
"Location": "South Carolina",
"On Demand (USD)": 2.7000,
"1-year Commitment (USD)": 1.8900,
"3-year Commitment (USD)": 1.2200
},
"us-east5": {
"Location": "Ohio",
"On Demand (USD)": 2.7000,
"1-year Commitment (USD)": 1.8900,
"3-year Commitment (USD)": 1.2200
},
"europe-west4": {
"Location": "Netherlands",
"On Demand (USD)": 2.9700,
"1-year Commitment (USD)": 2.0800,
"3-year Commitment (USD)": 1.3400
},
"asia-northeast1": {
"Location": "Tokio",
"On Demand (USD)": 3.2400,
"1-year Commitment (USD)": 2.2700,
"3-year Commitment (USD)": 1.4600
}
},
"TPU v5p": {
"us-east5": {
"Location": "Ohio",
Expand Down

0 comments on commit edc13c0

Please sign in to comment.