From 23312da788bb333c8f2006de5b5b078223e6163d Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 12 Jun 2024 18:58:04 +0545 Subject: [PATCH 1/2] Add minimum constraints for TPU --- src/gpuhunt/_internal/catalog.py | 14 ++++++++++++-- src/gpuhunt/_internal/constraints.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/gpuhunt/_internal/catalog.py b/src/gpuhunt/_internal/catalog.py index 78675f7..3a83e64 100644 --- a/src/gpuhunt/_internal/catalog.py +++ b/src/gpuhunt/_internal/catalog.py @@ -201,8 +201,18 @@ def _get_offline_provider_items( ) for row in reader: item = CatalogItem.from_dict(row, provider=provider_name) - if constraints.matches(item, query_filter): - items.append(item) + # tpus does not specify cpu and memory hence different + # constraints matching is required. + if query_filter.gpu_name is not None: + if any("tpu" in name for name in query_filter.gpu_name): + if constraints.tpu_matches(item, query_filter): + items.append(item) + else: + if constraints.matches(item, query_filter): + items.append(item) + else: + if constraints.matches(item, query_filter): + items.append(item) return items def _get_online_provider_items( diff --git a/src/gpuhunt/_internal/constraints.py b/src/gpuhunt/_internal/constraints.py index abfd61f..be90a71 100644 --- a/src/gpuhunt/_internal/constraints.py +++ b/src/gpuhunt/_internal/constraints.py @@ -69,6 +69,22 @@ def matches(i: CatalogItem, q: QueryFilter) -> bool: return True +def tpu_matches(i: CatalogItem, q: QueryFilter) -> bool: + if q.gpu_name is not None: + if i.gpu_name is None: + return False + if i.gpu_name.lower() not in q.gpu_name: + return False + if i.disk_size is not None: + if not is_between(i.disk_size, q.min_disk_size, q.max_disk_size): + return False + if not is_between(i.price, q.min_price, q.max_price): + return False + if q.spot is not None and i.spot != q.spot: + return False + return True + + def get_compute_capability(gpu_name: str) -> Optional[Tuple[int, int]]: for gpu in KNOWN_GPUS: if gpu.name.lower() == gpu_name.lower(): From 09ff973a0a057e77b96ba045037e01d03bd71ff3 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Wed, 12 Jun 2024 18:58:04 +0545 Subject: [PATCH 2/2] Add minimum constraints for TPU --- src/gpuhunt/_internal/catalog.py | 4 ++-- src/gpuhunt/_internal/utils.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/gpuhunt/_internal/catalog.py b/src/gpuhunt/_internal/catalog.py index 3a83e64..1678053 100644 --- a/src/gpuhunt/_internal/catalog.py +++ b/src/gpuhunt/_internal/catalog.py @@ -11,7 +11,7 @@ import gpuhunt._internal.constraints as constraints from gpuhunt._internal.models import CatalogItem, QueryFilter -from gpuhunt._internal.utils import parse_compute_capability +from gpuhunt._internal.utils import _is_tpu, parse_compute_capability from gpuhunt.providers import AbstractProvider logger = logging.getLogger(__name__) @@ -204,7 +204,7 @@ def _get_offline_provider_items( # tpus does not specify cpu and memory hence different # constraints matching is required. if query_filter.gpu_name is not None: - if any("tpu" in name for name in query_filter.gpu_name): + if any(_is_tpu(name) for name in query_filter.gpu_name): if constraints.tpu_matches(item, query_filter): items.append(item) else: diff --git a/src/gpuhunt/_internal/utils.py b/src/gpuhunt/_internal/utils.py index 04a5abe..c16b2f5 100644 --- a/src/gpuhunt/_internal/utils.py +++ b/src/gpuhunt/_internal/utils.py @@ -23,3 +23,14 @@ def to_camel_case(snake_case: str) -> str: words = list(filter(None, words)) words[1:] = [word[:1].upper() + word[1:] for word in words[1:]] return "".join(words) + + +def _is_tpu(name: str) -> bool: + tpu_versions = ["tpu-v2", "tpu-v3", "tpu-v4", "tpu-v5p", "tpu-v5litepod"] + parts = name.split("-") + if len(parts) == 3: + version = f"{parts[0]}-{parts[1]}" + cores = parts[2] + if version in tpu_versions and cores.isdigit(): + return True + return False