Skip to content

Commit

Permalink
Balance resources (#29)
Browse files (browse the repository at this point in the history)
* Reimplement TensorDock.optimize_offers, rename fill_missing to balance_resources

* Implement balance_resources option for TensorDock

* Decrease RAM per core

* Revert RAM per core

* Fix typevar bound
  • Loading branch information
Egor-S authored Nov 28, 2023
1 parent cd8a473 commit 9ee8cf8
Show file tree
Hide file tree
Showing 16 changed files with 159 additions and 210 deletions.
2 changes: 1 addition & 1 deletion src/gpuhunt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gpuhunt._internal.catalog import Catalog
from gpuhunt._internal.constraints import fill_missing, matches
from gpuhunt._internal.constraints import matches
from gpuhunt._internal.default import default_catalog, query
from gpuhunt._internal.models import CatalogItem, QueryFilter, RawCatalogItem
16 changes: 8 additions & 8 deletions src/gpuhunt/_internal/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@


class Catalog:
def __init__(self, fill_missing: bool = True, auto_reload: bool = True):
def __init__(self, balance_resources: bool = True, auto_reload: bool = True):
"""
Args:
fill_missing: derive missing constraints from other constraints
balance_resources: increase min resources to better match the chosen GPU
auto_reload: if `True`, the catalog will be automatically loaded from the S3 bucket every 4 hours
"""
self.catalog = None
self.loaded_at = None
self.providers: List[AbstractProvider] = []
self.fill_missing = fill_missing
self.balance_resources = balance_resources
self.auto_reload = auto_reload

def query(
Expand Down Expand Up @@ -133,7 +133,6 @@ def query(
self._get_online_provider_items,
provider_name,
query_filter,
self.fill_missing,
)
)

Expand All @@ -144,7 +143,6 @@ def query(
self._get_offline_provider_items,
provider_name,
query_filter,
self.fill_missing,
)
)

Expand Down Expand Up @@ -186,7 +184,7 @@ def add_provider(self, provider: AbstractProvider):
self.providers.append(provider)

def _get_offline_provider_items(
self, provider_name: str, query_filter: QueryFilter, fill_missing: bool
self, provider_name: str, query_filter: QueryFilter
) -> List[CatalogItem]:
logger.debug("Loading items for offline provider %s", provider_name)

Expand All @@ -208,7 +206,7 @@ def _get_offline_provider_items(
return items

def _get_online_provider_items(
self, provider_name: str, query_filter: QueryFilter, fill_missing: bool
self, provider_name: str, query_filter: QueryFilter
) -> List[CatalogItem]:
logger.debug("Loading items for online provider %s", provider_name)
items = []
Expand All @@ -217,7 +215,9 @@ def _get_online_provider_items(
if provider.NAME != provider_name:
continue
found = True
for i in provider.get(query_filter=query_filter, fill_missing=fill_missing):
for i in provider.get(
query_filter=query_filter, balance_resources=self.balance_resources
):
item = CatalogItem(provider=provider_name, **dataclasses.asdict(i))
if constraints.matches(item, query_filter):
items.append(item)
Expand Down
64 changes: 0 additions & 64 deletions src/gpuhunt/_internal/constraints.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,7 @@
import copy
from typing import Optional, Tuple, TypeVar, Union

from gpuhunt._internal.models import CatalogItem, GPUInfo, QueryFilter


def fill_missing(q: QueryFilter, *, memory_per_core: int = 6) -> QueryFilter:
    """Derive unset resource constraints of a query from the ones that are set.

    The input filter is deep-copied first, so the caller's object is never
    modified; all derived values are written to the copy.

    Derivation order (each step only fills a field that is still ``None``):

    1. If any GPU-related field is set, estimate the minimal total GPU memory:
       ``min_total_gpu_memory`` when given, otherwise the largest per-GPU
       memory candidate multiplied by the GPU count (1 when unset).
    2. ``min_memory`` becomes ``2 * total GPU memory`` and ``min_disk_size``
       becomes ``30 + total GPU memory``.
    3. ``min_cpu`` becomes ``min_memory / memory_per_core`` rounded up.
    4. ``min_memory`` becomes ``memory_per_core * min_cpu``.

    Args:
        q: the query filter to complete.
        memory_per_core: ratio used to convert between ``min_memory`` and
            ``min_cpu`` (units are not visible here; presumably GB per core —
            confirm against ``QueryFilter``).

    Returns:
        A copy of ``q`` with the derivable fields filled in.
    """
    q = copy.deepcopy(q)

    # if there is some information about gpu
    min_total_gpu_memory = None
    if any(
        value is not None
        for value in (
            q.gpu_name,
            q.min_gpu_count,
            q.min_gpu_memory,
            q.min_total_gpu_memory,
            q.min_compute_capability,
        )
    ):
        if q.min_total_gpu_memory is not None:
            # an explicit total wins over anything derived below
            min_total_gpu_memory = q.min_total_gpu_memory
        else:
            min_gpu_count = 1 if q.min_gpu_count is None else q.min_gpu_count
            # candidate lower bounds for the memory of a single GPU
            min_gpu_memory = []
            if q.min_gpu_memory is not None:
                min_gpu_memory.append(q.min_gpu_memory)
            # KNOWN_GPUS is defined elsewhere in this module (not visible in this hunk)
            gpus = KNOWN_GPUS
            if q.min_compute_capability is not None:  # filter gpus by compute capability
                gpus = [i for i in gpus if i.compute_capability >= q.min_compute_capability]
            if q.gpu_name is not None:  # filter gpus by name
                # NOTE(review): presumably q.gpu_name is a collection of lowercase names — confirm
                gpus = [i for i in gpus if i.name.lower() in q.gpu_name]
            # smallest memory among the matching GPUs; if the filters left
            # nothing, fall back to the smallest memory of any known GPU
            min_gpu_memory.append(
                min((i.memory for i in gpus), default=min(i.memory for i in KNOWN_GPUS))
            )
            min_total_gpu_memory = max(min_gpu_memory) * min_gpu_count

    if min_total_gpu_memory is not None:
        if q.min_memory is None:  # gpu memory to memory
            q.min_memory = 2 * min_total_gpu_memory
        if q.min_disk_size is None:  # gpu memory to disk
            q.min_disk_size = 30 + min_total_gpu_memory

    if q.min_memory is not None:
        if q.min_cpu is None:  # memory to cpu
            # integer ceiling division
            q.min_cpu = (q.min_memory + memory_per_core - 1) // memory_per_core

    if q.min_cpu is not None:
        # only reachable with min_memory unset when the caller gave min_cpu
        # but no memory and no GPU information
        if q.min_memory is None:  # cpu to memory
            q.min_memory = memory_per_core * q.min_cpu

    return q


Number = TypeVar("Number", bound=Union[int, float])


def optimize(
    available: Number, min_limit: Number, max_limit: Optional[Number]
) -> Optional[Number]:
    """Choose the amount of a resource to request from what is available.

    The available amount is first capped at ``max_limit`` whenever
    ``is_above`` reports that it exceeds that limit. If ``is_below`` then
    reports the capped amount as under ``min_limit``, the request cannot be
    satisfied.

    Args:
        available: amount of the resource on offer.
        min_limit: smallest acceptable amount.
        max_limit: largest acceptable amount; it is passed to ``is_above``
            unchanged, so ``None`` is handled there (presumably meaning
            "no upper bound" — confirm against ``is_above``).

    Returns:
        ``min_limit`` when the request can be satisfied, otherwise ``None``.
    """
    capped = max_limit if is_above(available, max_limit) else available
    return None if is_below(capped, min_limit) else min_limit


Comparable = TypeVar("Comparable", bound=Union[int, float, Tuple[int, int]])


Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class AbstractProvider(ABC):

@abstractmethod
def get(
self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
) -> List[RawCatalogItem]:
pass

Expand Down
4 changes: 3 additions & 1 deletion src/gpuhunt/providers/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def __init__(self, cache_path: Optional[str] = None):
"p4de.24xlarge": ("A100", 80.0),
}

def get(self, query_filter: Optional[QueryFilter] = None) -> List[RawCatalogItem]:
def get(
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
) -> List[RawCatalogItem]:
if not os.path.exists(self.cache_path):
logger.info("Downloading EC2 prices to %s", self.cache_path)
with requests.get(ec2_pricing_url, stream=True) as r:
Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/providers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _get_pages_worker(self, q: Queue, stride: int, worker_id: int):
q.put(None)

def get(
self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
) -> List[RawCatalogItem]:
offers = []
for page in self.get_pages():
Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/providers/datacrunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, client_id: str, client_secret: str) -> None:
self.datacrunch_client = DataCrunchClient(client_id, client_secret)

def get(
self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
) -> List[RawCatalogItem]:
instance_types = self._get_instance_types()
locations = self._get_locations()
Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/providers/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def fill_prices(self, instances: List[RawCatalogItem]) -> List[RawCatalogItem]:
return offers

def get(
self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
) -> List[RawCatalogItem]:
instances = self.list_preconfigured_instances()
self.add_gpus(instances)
Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/providers/lambdalabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, token: str):
self.token = token

def get(
self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
) -> List[RawCatalogItem]:
offers = []
data = requests.get(
Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/providers/nebius.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, service_account: "ServiceAccount"):
self.api_client = NebiusAPIClient(service_account)

def get(
self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
) -> List[RawCatalogItem]:
zone = self.api_client.compute_zones_list()[0]["id"]
skus = []
Expand Down
Loading

0 comments on commit 9ee8cf8

Please sign in to comment.