Skip to content

Commit

Permalink
Add TPU support in GCP
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihan Rana authored and Bihan Rana committed Jun 4, 2024
1 parent e1effee commit ad703ef
Show file tree
Hide file tree
Showing 5 changed files with 433 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ azure = [
]
gcp = [
"google-cloud-billing",
"google-cloud-compute"
"google-cloud-compute",
"google-cloud-tpu"
]
nebius = [
"pyjwt",
Expand Down
287 changes: 287 additions & 0 deletions src/gpuhunt/providers/gcp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import copy
import importlib
import json
import logging
import re
from collections import defaultdict, namedtuple
from typing import List, Optional

import google.cloud.billing_v1 as billing_v1
import google.cloud.compute_v1 as compute_v1
from google.cloud import tpu_v2
from google.cloud.billing_v1 import CloudCatalogClient, ListSkusRequest
from google.cloud.location import locations_pb2
from google.cloud.location.locations_pb2 import ListLocationsResponse

from gpuhunt._internal.models import QueryFilter, RawCatalogItem
from gpuhunt.providers import AbstractProvider
Expand Down Expand Up @@ -44,6 +50,17 @@
accelerator_counts = [1, 2, 4, 8, 16]


def load_tpu_pricing():
resource_package = "gpuhunt.resources"
resource_name = "tpu_pricing.json"

with importlib.resources.open_text(resource_package, resource_name) as f:
return json.load(f)


tpu_pricing: dict = load_tpu_pricing()


class GCPProvider(AbstractProvider):
NAME = "gcp"

Expand Down Expand Up @@ -211,6 +228,8 @@ def get(
instances = self.list_preconfigured_instances()
self.add_gpus(instances)
offers = self.fill_prices(instances)
# Add tpu offerings
offers.extend(get_tpu_offers(self.project))
return sorted(offers, key=lambda i: i.price)

@classmethod
Expand All @@ -232,3 +251,271 @@ def filter(cls, offers: List[RawCatalogItem]) -> List[RawCatalogItem]:
)
or (i.gpu_name and i.gpu_name not in ["K80", "P4"])
]


def get_tpu_offers(project_id: str) -> List[RawCatalogItem]:
logger.info("Fetching tpu offers")
raw_catalog_items: List[RawCatalogItem] = []
catalog_items: List[dict] = get_catalog_items(project_id)
filtered_catalog_items = [item for item in catalog_items if item["price"] is not None]
for item in filtered_catalog_items:
on_demand_item = RawCatalogItem(
instance_name=item["instance_name"],
location=item["location"],
price=item["price"],
cpu=0,
memory=0,
gpu_count=1,
gpu_name=f'tpu-{item["instance_name"]}',
gpu_memory=0,
spot=False,
disk_size=None,
)
raw_catalog_items.append(on_demand_item)
if item["spot"]:
spot_item = copy.deepcopy(on_demand_item)
spot_item.price = item["spot"]
spot_item.spot = True
raw_catalog_items.append(spot_item)
return raw_catalog_items


def get_catalog_items(project_id: str) -> List[dict]:
tpu_prices: List[dict] = get_tpu_prices()
configs: List[dict] = get_tpu_configs(project_id)
for config in configs:
instance_name = config["instance_name"]
location = config["location"].rsplit("-", 1)[0] # Remove the part after the last '-'
no_of_chips = config["no_of_chips"]
tpu_version, no_of_cores = instance_name.rsplit("-", 1)
no_of_cores = int(no_of_cores)
if tpu_version in ["v5litepod", "v5p"]:
# For TPU-v5 series, api provides per chip price.
# Verify per chip price in the following link.https://cloud.google.com/tpu/pricing
is_pod = True
on_demand_base_price = find_base_price_v5(
tpu_version, location, tpu_prices, spot=False
)
spot_base_price = find_base_price_v5(tpu_version, location, tpu_prices, spot=True)
if on_demand_base_price is not None:
on_demand_price = on_demand_base_price * no_of_chips
else:
on_demand_price = find_tpu_price_static_src(
tpu_version, no_of_cores, location, no_of_chips, False
)
if spot_base_price is not None:
spot_price = spot_base_price * no_of_chips
else:
spot_price = find_tpu_price_static_src(
tpu_version, no_of_cores, location, no_of_chips, True
)
elif tpu_version in ["v2", "v3", "v4"]:
# For TPU-v2 and TPU-v3, the pricing API provides the prices of 8 TPU cores.
# For TPU-v4, api only provides the price of TPU-v4 pods.
if no_of_cores > 8 or tpu_version == "v4":
is_pod = True
base_instance_name = f"{tpu_version}-8"
base_no_of_chips = find_no_of_chips(base_instance_name, configs)
on_demand_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=False, is_pod=True
)
spot_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=True, is_pod=True
)

if on_demand_base_price is not None and base_no_of_chips is not None:
on_demand_price = (on_demand_base_price / base_no_of_chips) * no_of_chips
else:
on_demand_price = find_tpu_price_static_src(
tpu_version, no_of_cores, location, no_of_chips, False
)
if spot_base_price is not None and base_no_of_chips is not None:
spot_price = (spot_base_price / base_no_of_chips) * no_of_chips
else:
spot_price = find_tpu_price_static_src(
tpu_version, no_of_cores, location, no_of_chips, True
)

elif no_of_cores == 8:
is_pod = False
on_demand_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=False, is_pod=False
)
spot_base_price = find_base_price(
tpu_version, location, tpu_prices, spot=True, is_pod=False
)
on_demand_price = (
on_demand_base_price if on_demand_base_price is not None else None
)
spot_price = spot_base_price if spot_base_price is not None else None
base_no_of_chips = no_of_chips

config["price"] = on_demand_price
config["spot"] = spot_price
config["is_pod"] = is_pod
config["base_price"] = on_demand_base_price
config["base_no_of_chips"] = base_no_of_chips
config["spot_base_price"] = spot_base_price
return configs


def get_tpu_prices() -> List[dict]:
client = CloudCatalogClient()
tpu_configs = []
# E000-3F24-B8AA contains prices for TPU versions v2,v3,v4.
# 6F81-5844-456A contains prices for TPU versions v5p and v5litepod(v5e)
service_names = ["services/E000-3F24-B8AA", "services/6F81-5844-456A"]

# Loop through each service name and list SKUs
for service_name in service_names:
# Create the request
request = ListSkusRequest(parent=service_name)

# List SKUs
response = client.list_skus(request=request)

for sku in response.skus:
if sku.category.resource_group != "TPU":
continue
if sku.category.usage_type not in ["OnDemand", "Preemptible"]:
continue

tpu_version = extract_tpu_version(sku.description)
if tpu_version:
is_pod = True if "Pod" in sku.description else False
spot = True if "Preemptible" in sku.description else False
price = sku.pricing_info[0].pricing_expression.tiered_rates[0].unit_price
price = price.units + price.nanos / 1e9
tpu_configs.append(
{
"instance_name": tpu_version,
"is_pod": is_pod,
"spot": spot,
"regions": sku.service_regions,
"price": price,
"description": sku.description,
}
)
return tpu_configs


def find_base_price(
instance_name: str, location: str, tpu_prices: List[dict], spot: bool, is_pod: bool
) -> Optional[float]:
for price_info in tpu_prices:
if (
price_info["instance_name"] == instance_name
and any(loc.startswith(location) for loc in price_info["regions"])
and price_info["spot"] == spot
and price_info["is_pod"] == is_pod
):
return price_info["price"]
return None


def find_no_of_chips(instance_name: str, configs: List[dict]):
for config in configs:
if config["instance_name"] == instance_name:
return config["no_of_chips"]
return None


def find_tpu_price_static_src(
tpu_version: str, num_cores: int, tpu_region: str, no_of_chips: int, spot: bool
) -> Optional[float]:
# Pricing table names v5litepod as v5e
tpu_version = "v5e" if tpu_version == "v5litepod" else tpu_version
is_pod = num_cores > 8 or tpu_version == "v4"
tpu_type = f"TPU {tpu_version} pod" if is_pod else f"TPU {tpu_version} device"
# Comment here
if tpu_version == "v5p" or tpu_version == "v5e":
tpu_type = f"TPU {tpu_version}"
try:
on_demand_price = tpu_pricing[tpu_type][tpu_region]["On Demand (USD)"] * no_of_chips
spot_price = tpu_pricing[tpu_type][tpu_region]["Spot (USD)"] * no_of_chips
return on_demand_price if not spot else spot_price
except KeyError:
logger.warning(
f'key error for {tpu_type} {tpu_region} {"On Demand (USD)" if spot else "Spot (USD)"}'
)
return None


def find_base_price_v5(
instance_name: str, location: str, tpu_prices: List[dict], spot: bool
) -> Optional[float]:
for price_info in tpu_prices:
if (
price_info["instance_name"] == instance_name
and any(loc.startswith(location) for loc in price_info["regions"])
and price_info["spot"] == spot
):
return price_info["price"]
return None


def get_tpu_configs(project_id: str) -> List[dict]:
instances: List[dict] = []
client = tpu_v2.TpuClient()
for location in get_locations(project_id):
parent = f"projects/{project_id}/locations/{location}"
request = tpu_v2.ListAcceleratorTypesRequest(
parent=parent,
)
# request = tpu_v1.ListAcceleratorTypesRequest(parent=parent)
page_result = client.list_accelerator_types(request=request)
for response in page_result:
no_of_chips = get_no_of_chips(response.accelerator_configs[0].topology)
instances.append(
{
"instance_name": response.type_,
"location": location,
"no_of_chips": no_of_chips,
"topology": response.accelerator_configs[0].topology,
}
)
return instances


def get_no_of_chips(expression: str) -> int:
# Split the expression by 'x'
factors = expression.split("x")
# Convert each factor to an integer
factors = map(int, factors)
# Calculate the product by multiplying all factors
product = 1
for factor in factors:
product *= factor
return product


def get_locations(project_id: str) -> List[str]:
client = tpu_v2.TpuClient()
# Initialize request argument(s)
parent = f"projects/{project_id}"
list_locations_request: ListLocationsResponse = client.list_locations(
locations_pb2.ListLocationsRequest(name=parent)
)
locations = [loc.location_id for loc in list_locations_request.locations]
# TPU V4 only available in location us-central2-b only.
# us-central2-b needs to be enabled in the project.
return locations


def extract_tpu_version(input_string: str) -> Optional[str]:
# The regular expression pattern to find a substring starting with 'Tpu'
pattern = r"\bTpu[-\w]*\b"

# Search for the first match of the pattern
match = re.search(pattern, input_string, re.IGNORECASE)
if match:
tpu_match = match.group().lower()
# The regular expression pattern to find the version part
version_pattern = r"v\d+[a-z]*"
version_match = re.search(version_pattern, tpu_match)
if version_match:
# Name of v5e in gcp console is v5litepod
version = "v5litepod" if version_match.group() == "v5e" else version_match.group()
return version

return None
Empty file.
Loading

0 comments on commit ad703ef

Please sign in to comment.