diff --git a/src/gpuhunt/providers/gcp.py b/src/gpuhunt/providers/gcp.py index 35e60b6..12332b8 100644 --- a/src/gpuhunt/providers/gcp.py +++ b/src/gpuhunt/providers/gcp.py @@ -5,6 +5,7 @@ import re from collections import defaultdict, namedtuple from collections.abc import Iterable +from dataclasses import dataclass from typing import Optional import google.cloud.billing_v1 as billing_v1 @@ -13,7 +14,6 @@ from google.cloud.billing_v1 import CloudCatalogClient, ListSkusRequest from google.cloud.billing_v1.types.cloud_catalog import Sku from google.cloud.location import locations_pb2 -from pydantic import BaseModel from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem from gpuhunt.providers import AbstractProvider @@ -90,13 +90,27 @@ } -class TPUHardwareSpec(BaseModel): +@dataclass +class TPUHardwareSpec: name: str cpu: int memory_gb: int hbm_gb: int +TPU_HARDWARE_SPECS = [ + TPUHardwareSpec(name="v2-8", cpu=96, memory_gb=334, hbm_gb=64), + TPUHardwareSpec(name="v3-8", cpu=96, memory_gb=334, hbm_gb=128), + TPUHardwareSpec(name="v5litepod-1", cpu=24, memory_gb=48, hbm_gb=16), + TPUHardwareSpec(name="v5litepod-2", cpu=112, memory_gb=192, hbm_gb=16), + TPUHardwareSpec(name="v5litepod-8", cpu=224, memory_gb=384, hbm_gb=128), + TPUHardwareSpec(name="v5p-8", cpu=208, memory_gb=448, hbm_gb=95), + TPUHardwareSpec(name="v6e-1", cpu=44, memory_gb=176, hbm_gb=32), + TPUHardwareSpec(name="v6e-4", cpu=180, memory_gb=720, hbm_gb=128), + TPUHardwareSpec(name="v6e-8", cpu=180, memory_gb=1440, hbm_gb=256), +] + + # For newer TPUs, the specs are described in the docs: https://cloud.google.com/tpu/docs/v6e # For older TPUs, the specs are collected manually from running instances. TPU_HARDWARE_SPECS = [