diff --git a/src/gpuhunt/__init__.py b/src/gpuhunt/__init__.py index 51ff23c..d6eb374 100644 --- a/src/gpuhunt/__init__.py +++ b/src/gpuhunt/__init__.py @@ -1,4 +1,4 @@ from gpuhunt._internal.catalog import Catalog -from gpuhunt._internal.constraints import fill_missing, matches +from gpuhunt._internal.constraints import matches from gpuhunt._internal.default import default_catalog, query from gpuhunt._internal.models import CatalogItem, QueryFilter, RawCatalogItem diff --git a/src/gpuhunt/_internal/catalog.py b/src/gpuhunt/_internal/catalog.py index bb8f31b..00f5f50 100644 --- a/src/gpuhunt/_internal/catalog.py +++ b/src/gpuhunt/_internal/catalog.py @@ -23,16 +23,16 @@ class Catalog: - def __init__(self, fill_missing: bool = True, auto_reload: bool = True): + def __init__(self, balance_resources: bool = True, auto_reload: bool = True): """ Args: - fill_missing: derive missing constraints from other constraints + balance_resources: increase min resources to better match the chosen GPU auto_reload: if `True`, the catalog will be automatically loaded from the S3 bucket every 4 hours """ self.catalog = None self.loaded_at = None self.providers: List[AbstractProvider] = [] - self.fill_missing = fill_missing + self.balance_resources = balance_resources self.auto_reload = auto_reload def query( @@ -133,7 +133,6 @@ def query( self._get_online_provider_items, provider_name, query_filter, - self.fill_missing, ) ) @@ -144,7 +143,6 @@ def query( self._get_offline_provider_items, provider_name, query_filter, - self.fill_missing, ) ) @@ -186,7 +184,7 @@ def add_provider(self, provider: AbstractProvider): self.providers.append(provider) def _get_offline_provider_items( - self, provider_name: str, query_filter: QueryFilter, fill_missing: bool + self, provider_name: str, query_filter: QueryFilter ) -> List[CatalogItem]: logger.debug("Loading items for offline provider %s", provider_name) @@ -208,7 +206,7 @@ def _get_offline_provider_items( return items def _get_online_provider_items( - self, provider_name: str, query_filter: QueryFilter, fill_missing: bool + self, provider_name: str, query_filter: QueryFilter ) -> List[CatalogItem]: logger.debug("Loading items for online provider %s", provider_name) items = [] @@ -217,7 +215,9 @@ def _get_online_provider_items( if provider.NAME != provider_name: continue found = True - for i in provider.get(query_filter=query_filter, fill_missing=fill_missing): + for i in provider.get( + query_filter=query_filter, balance_resources=self.balance_resources + ): item = CatalogItem(provider=provider_name, **dataclasses.asdict(i)) if constraints.matches(item, query_filter): items.append(item) diff --git a/src/gpuhunt/_internal/constraints.py b/src/gpuhunt/_internal/constraints.py index 46d6b9b..1b5fc46 100644 --- a/src/gpuhunt/_internal/constraints.py +++ b/src/gpuhunt/_internal/constraints.py @@ -1,71 +1,7 @@ -import copy from typing import Optional, Tuple, TypeVar, Union from gpuhunt._internal.models import CatalogItem, GPUInfo, QueryFilter - -def fill_missing(q: QueryFilter, *, memory_per_core: int = 6) -> QueryFilter: - q = copy.deepcopy(q) - - # if there is some information about gpu - min_total_gpu_memory = None - if any( - value is not None - for value in ( - q.gpu_name, - q.min_gpu_count, - q.min_gpu_memory, - q.min_total_gpu_memory, - q.min_compute_capability, - ) - ): - if q.min_total_gpu_memory is not None: - min_total_gpu_memory = q.min_total_gpu_memory - else: - min_gpu_count = 1 if q.min_gpu_count is None else q.min_gpu_count - min_gpu_memory = [] - if q.min_gpu_memory is not None: - min_gpu_memory.append(q.min_gpu_memory) - gpus = KNOWN_GPUS - if q.min_compute_capability is not None: # filter gpus by compute capability - gpus = [i for i in gpus if i.compute_capability >= q.min_compute_capability] - if q.gpu_name is not None: # filter gpus by name - gpus = [i for i in gpus if i.name.lower() in q.gpu_name] - min_gpu_memory.append( - min((i.memory for i in gpus), default=min(i.memory for i in KNOWN_GPUS)) - ) - min_total_gpu_memory = max(min_gpu_memory) * min_gpu_count - - if min_total_gpu_memory is not None: - if q.min_memory is None: # gpu memory to memory - q.min_memory = 2 * min_total_gpu_memory - if q.min_disk_size is None: # gpu memory to disk - q.min_disk_size = 30 + min_total_gpu_memory - - if q.min_memory is not None: - if q.min_cpu is None: # memory to cpu - q.min_cpu = (q.min_memory + memory_per_core - 1) // memory_per_core - - if q.min_cpu is not None: - if q.min_memory is None: # cpu to memory - q.min_memory = memory_per_core * q.min_cpu - - return q - - -Number = TypeVar("Number", bound=Union[int, float]) - - -def optimize( - available: Number, min_limit: Number, max_limit: Optional[Number] -) -> Optional[Number]: - if is_above(available, max_limit): - available = max_limit - if is_below(available, min_limit): - return None - return min_limit - - Comparable = TypeVar("Comparable", bound=Union[int, float, Tuple[int, int]]) diff --git a/src/gpuhunt/providers/__init__.py b/src/gpuhunt/providers/__init__.py index be79524..c5211df 100644 --- a/src/gpuhunt/providers/__init__.py +++ b/src/gpuhunt/providers/__init__.py @@ -9,7 +9,7 @@ class AbstractProvider(ABC): @abstractmethod def get( - self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> List[RawCatalogItem]: pass diff --git a/src/gpuhunt/providers/aws.py b/src/gpuhunt/providers/aws.py index c890b55..25ff749 100644 --- a/src/gpuhunt/providers/aws.py +++ b/src/gpuhunt/providers/aws.py @@ -69,7 +69,9 @@ def __init__(self, cache_path: Optional[str] = None): "p4de.24xlarge": ("A100", 80.0), } - def get(self, query_filter: Optional[QueryFilter] = None) -> List[RawCatalogItem]: + def get( + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True + ) -> List[RawCatalogItem]: if not os.path.exists(self.cache_path): logger.info("Downloading EC2 prices to %s", self.cache_path) with requests.get(ec2_pricing_url, stream=True) as r: diff --git a/src/gpuhunt/providers/azure.py b/src/gpuhunt/providers/azure.py index 4462ddd..a904faf 100644 --- a/src/gpuhunt/providers/azure.py +++ b/src/gpuhunt/providers/azure.py @@ -134,7 +134,7 @@ def _get_pages_worker(self, q: Queue, stride: int, worker_id: int): q.put(None) def get( - self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> List[RawCatalogItem]: offers = [] for page in self.get_pages(): diff --git a/src/gpuhunt/providers/datacrunch.py b/src/gpuhunt/providers/datacrunch.py index 4e296ad..8ec867c 100644 --- a/src/gpuhunt/providers/datacrunch.py +++ b/src/gpuhunt/providers/datacrunch.py @@ -17,7 +17,7 @@ def __init__(self, client_id: str, client_secret: str) -> None: self.datacrunch_client = DataCrunchClient(client_id, client_secret) def get( - self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> List[RawCatalogItem]: instance_types = self._get_instance_types() locations = self._get_locations() diff --git a/src/gpuhunt/providers/gcp.py b/src/gpuhunt/providers/gcp.py index 0f3f602..6f243e4 100644 --- a/src/gpuhunt/providers/gcp.py +++ b/src/gpuhunt/providers/gcp.py @@ -205,7 +205,7 @@ def fill_prices(self, instances: List[RawCatalogItem]) -> List[RawCatalogItem]: return offers def get( - self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> List[RawCatalogItem]: instances = self.list_preconfigured_instances() self.add_gpus(instances) diff --git a/src/gpuhunt/providers/lambdalabs.py b/src/gpuhunt/providers/lambdalabs.py index 59bd9ea..f3b408f 100644 --- a/src/gpuhunt/providers/lambdalabs.py +++ b/src/gpuhunt/providers/lambdalabs.py @@ -32,7 +32,7 @@ def __init__(self, token: str): self.token = token def get( - self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> List[RawCatalogItem]: offers = [] data = requests.get( diff --git a/src/gpuhunt/providers/nebius.py b/src/gpuhunt/providers/nebius.py index 052af30..ffa32ab 100644 --- a/src/gpuhunt/providers/nebius.py +++ b/src/gpuhunt/providers/nebius.py @@ -79,7 +79,7 @@ def __init__(self, service_account: "ServiceAccount"): self.api_client = NebiusAPIClient(service_account) def get( - self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> List[RawCatalogItem]: zone = self.api_client.compute_zones_list()[0]["id"] skus = [] diff --git a/src/gpuhunt/providers/tensordock.py b/src/gpuhunt/providers/tensordock.py index 5c69587..0da8858 100644 --- a/src/gpuhunt/providers/tensordock.py +++ b/src/gpuhunt/providers/tensordock.py @@ -1,10 +1,10 @@ import logging -from typing import List, Optional, Union +from math import ceil +from typing import List, Optional, TypeVar, Union import requests -from gpuhunt._internal.constraints import fill_missing as constraints_fill_missing -from gpuhunt._internal.constraints import get_compute_capability, is_between, optimize +from gpuhunt._internal.constraints import get_compute_capability, is_between from gpuhunt._internal.models import QueryFilter, RawCatalogItem from gpuhunt.providers import AbstractProvider @@ -31,26 +31,33 @@ "v100-pcie-16gb": "V100", } +RAM_PER_VRAM = 2 +RAM_PER_CORE = 6 +CPU_DIV = 2 # has to be even +RAM_DIV = 2 # has to be even + class TensorDockProvider(AbstractProvider): NAME = "tensordock" def get( - self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> List[RawCatalogItem]: logger.info("Fetching TensorDock offers") - if fill_missing: - query_filter = constraints_fill_missing(query_filter) - logger.debug("Effective query filter: %s", query_filter) - hostnodes = requests.get(marketplace_hostnodes_url).json()["hostnodes"] offers = [] for hostnode, details in hostnodes.items(): location = details["location"]["country"].lower().replace(" ", "") if query_filter is not None: - offers += self.optimize_offers(query_filter, details["specs"], hostnode, location) - else: + offers += self.optimize_offers( + query_filter, + details["specs"], + hostnode, + location, + balance_resources=balance_resources, + ) + else: # pick maximum possible configuration for gpu_name, gpu in details["specs"]["gpu"].items(): if gpu["amount"] == 0: continue @@ -79,8 +86,23 @@ def get( @staticmethod def optimize_offers( - q: QueryFilter, specs: dict, instance_name: str, location: str + q: QueryFilter, + specs: dict, + instance_name: str, + location: str, + balance_resources: bool = True, ) -> List[RawCatalogItem]: + """ + Picks the best offer for the given query filter + Doesn't respect max values, additional filtering is required + + Args: + q: query filter + specs: hostnode specs + instance_name: hostnode `instance_name` + location: hostnode `location` + balance_resources: if True, will override query filter min values + """ offers = [] for gpu_model, gpu_info in specs["gpu"].items(): # filter by single gpu characteristics @@ -89,46 +111,86 @@ def optimize_offers( gpu_name = convert_gpu_name(gpu_model) if q.gpu_name is not None and gpu_name.lower() not in q.gpu_name: continue - if q.min_compute_capability is not None or q.max_compute_capability is not None: - cc = get_compute_capability(gpu_name) - if not cc or not is_between( - cc, q.min_compute_capability, q.max_compute_capability - ): - continue + cc = get_compute_capability(gpu_name) + if not cc or not is_between(cc, q.min_compute_capability, q.max_compute_capability): + continue for gpu_count in range(1, gpu_info["amount"] + 1): # try all possible gpu counts if not is_between(gpu_count, q.min_gpu_count, q.max_gpu_count): continue + total_gpu_memory = gpu_count * gpu_info["vram"] if not is_between( - gpu_count * gpu_info["vram"], q.min_total_gpu_memory, q.max_total_gpu_memory + total_gpu_memory, q.min_total_gpu_memory, q.max_total_gpu_memory ): continue + # we can't take 100% of CPU/RAM/storage if we don't take all GPUs multiplier = 0.75 if gpu_count < gpu_info["amount"] else 1 - cpu = optimize( # has to be even - round_down(int(multiplier * specs["cpu"]["amount"]), 2), - round_up(q.min_cpu or 1, 2), - round_down(q.max_cpu, 2) if q.max_cpu is not None else None, - ) - memory = optimize( # has to be even - round_down(int(multiplier * specs["ram"]["amount"]), 2), - round_up(q.min_memory or 1, 2), - round_down(q.max_memory, 2) if q.max_memory is not None else None, - ) - disk_size = optimize( # 30 GB at least for Ubuntu - int(multiplier * specs["storage"]["amount"]), - q.min_disk_size or 30, - q.max_disk_size, - ) - if cpu is None or memory is None or disk_size is None: - continue + available_memory = round_down(multiplier * specs["ram"]["amount"], RAM_DIV) + available_cpu = round_down(multiplier * specs["cpu"]["amount"], CPU_DIV) + available_disk = int(multiplier * specs["storage"]["amount"]) + + memory = None + if q.min_memory is not None: + if q.min_memory > available_memory: + continue + memory = round_up( + max_none( + q.min_memory, + gpu_count, # 1 GB per GPU at least + q.min_cpu, # 1 GB per CPU at least + ), + RAM_DIV, + ) + if memory is None or balance_resources: + memory = max_none( + memory, + min_none( + available_memory, + round_up(RAM_PER_VRAM * total_gpu_memory, RAM_DIV), + round_down(q.max_memory, RAM_DIV), # can be None + ), + ) + + cpu = None + if q.min_cpu is not None: + if q.min_cpu > available_cpu: + continue + # 1 CPU per GPU at least + cpu = round_up(max(q.min_cpu, gpu_count), CPU_DIV) + if cpu is None or balance_resources: + cpu = max_none( + cpu, + min_none( + available_cpu, + round_up(ceil(memory / RAM_PER_CORE), CPU_DIV), + round_down(q.max_cpu, CPU_DIV), # can be None + ), + ) + + disk_size = None + if q.min_disk_size is not None: + if q.min_disk_size > available_disk: + continue + disk_size = q.min_disk_size + if disk_size is None or balance_resources: + disk_size = max_none( + disk_size, + min_none( + available_disk, + max(memory, total_gpu_memory), + q.max_disk_size, # can be None + ), + ) + price = round( - cpu * specs["cpu"]["price"] - + memory * specs["ram"]["price"] + memory * specs["ram"]["price"] + + cpu * specs["cpu"]["price"] + disk_size * specs["storage"]["price"] + gpu_count * gpu_info["price"], 5, ) + offer = RawCatalogItem( instance_name=instance_name, location=location, @@ -145,14 +207,29 @@ def optimize_offers( return offers -def round_up(value: Union[int, float], step: int) -> int: +def round_up(value: Optional[Union[int, float]], step: int) -> Optional[int]: + if value is None: + return None return round_down(value + step - 1, step) -def round_down(value: Union[int, float], step: int) -> int: +def round_down(value: Optional[Union[int, float]], step: int) -> Optional[int]: + if value is None: + return None return value // step * step +T = TypeVar("T", bound=Union[int, float]) + + +def min_none(*args: Optional[T]) -> T: + return min(v for v in args if v is not None) + + +def max_none(*args: Optional[T]) -> T: + return max(v for v in args if v is not None) + + def convert_gpu_name(model: str) -> str: """ >>> convert_gpu_name("geforcegtx1070-pcie-8gb") diff --git a/src/gpuhunt/providers/vastai.py b/src/gpuhunt/providers/vastai.py index b7f364e..7ab98cd 100644 --- a/src/gpuhunt/providers/vastai.py +++ b/src/gpuhunt/providers/vastai.py @@ -22,7 +22,7 @@ def __init__(self, extra_filters: Optional[Dict[str, Dict[Operators, FilterValue self.extra_filters = extra_filters def get( - self, query_filter: Optional[QueryFilter] = None, fill_missing: bool = True + self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> List[RawCatalogItem]: filters: Dict[str, Any] = self.make_filters(query_filter or QueryFilter()) if self.extra_filters: diff --git a/src/tests/_internal/test_catalog.py b/src/tests/_internal/test_catalog.py index ca77336..ee1fcc3 100644 --- a/src/tests/_internal/test_catalog.py +++ b/src/tests/_internal/test_catalog.py @@ -9,7 +9,7 @@ class TestQuery: def test_query_merge(self): - catalog = Catalog(fill_missing=False, auto_reload=False) + catalog = Catalog(balance_resources=False, auto_reload=False) tensordock = TensorDockProvider() tensordock.get = Mock(return_value=[catalog_item(price=1), catalog_item(price=3)]) @@ -27,7 +27,7 @@ def test_query_merge(self): ] def test_no_providers_some_not_loaded(self): - catalog = Catalog(fill_missing=False, auto_reload=False) + catalog = Catalog(balance_resources=False, auto_reload=False) tensordock = TensorDockProvider() tensordock.get = Mock(return_value=[catalog_item(price=1)]) diff --git a/src/tests/_internal/test_constraints.py b/src/tests/_internal/test_constraints.py index 4877530..244accf 100644 --- a/src/tests/_internal/test_constraints.py +++ b/src/tests/_internal/test_constraints.py @@ -3,7 +3,7 @@ import pytest from gpuhunt import CatalogItem, QueryFilter -from gpuhunt._internal.constraints import fill_missing, matches +from gpuhunt._internal.constraints import matches @pytest.fixture @@ -123,75 +123,3 @@ def test_ti_gpu(self): def test_provider(self, cpu_items): assert matches(cpu_items[0], QueryFilter(provider=["datacrunch"])) assert matches(cpu_items[1], QueryFilter(provider=["nebius"])) - - -class TestFillMissing: - def test_empty(self): - assert fill_missing(QueryFilter(), memory_per_core=4) == QueryFilter() - - def test_from_cpu(self): - assert fill_missing(QueryFilter(min_cpu=2), memory_per_core=4) == QueryFilter( - min_cpu=2, - min_memory=8, - ) - - def test_from_memory(self): - assert fill_missing(QueryFilter(min_memory=6), memory_per_core=4) == QueryFilter( - min_memory=6, - min_cpu=2, - ) - - def test_from_total_gpu_memory(self): - assert fill_missing( - QueryFilter(min_total_gpu_memory=24), memory_per_core=4 - ) == QueryFilter( - min_total_gpu_memory=24, - min_memory=48, - min_disk_size=54, - min_cpu=12, - ) - - def test_from_gpu_memory(self): - assert fill_missing(QueryFilter(min_gpu_memory=16), memory_per_core=4) == QueryFilter( - min_gpu_memory=16, - min_memory=32, - min_disk_size=46, - min_cpu=8, - ) - - def test_from_gpu_count(self): - assert fill_missing(QueryFilter(min_gpu_count=2), memory_per_core=4) == QueryFilter( - min_gpu_count=2, - min_memory=32, # minimal GPU has 8 GB of memory - min_disk_size=46, - min_cpu=8, - ) - - def test_from_gpu_name(self): - assert fill_missing(QueryFilter(gpu_name=["A100"]), memory_per_core=4) == QueryFilter( - gpu_name=["A100"], - min_memory=80, - min_disk_size=70, - min_cpu=20, - ) - - def test_from_compute_capability(self): - assert fill_missing( - QueryFilter(min_compute_capability=(9, 0)), memory_per_core=4 - ) == QueryFilter( - min_compute_capability=(9, 0), - min_memory=160, - min_disk_size=110, - min_cpu=40, - ) - - def test_from_gpu_name_and_gpu_memory(self): - assert fill_missing( - QueryFilter(gpu_name=["A100"], min_gpu_memory=80), memory_per_core=4 - ) == QueryFilter( - gpu_name=["A100"], - min_gpu_memory=80, - min_memory=160, - min_disk_size=110, - min_cpu=40, - ) diff --git a/src/tests/providers/test_datacrunch.py b/src/tests/providers/test_datacrunch.py index 3c930e6..e494673 100644 --- a/src/tests/providers/test_datacrunch.py +++ b/src/tests/providers/test_datacrunch.py @@ -187,7 +187,7 @@ def transform(raw_catalog_items: List[RawCatalogItem]) -> List[CatalogItem]: def test_available_query(mocker, raw_instance_types): - catalog = Catalog(fill_missing=False, auto_reload=False) + catalog = Catalog(balance_resources=False, auto_reload=False) instance_type = instance_types(raw_instance_types[0]) @@ -233,7 +233,7 @@ def test_available_query(mocker, raw_instance_types): def test_available_query_with_instance(mocker, raw_instance_types): - catalog = Catalog(fill_missing=False, auto_reload=False) + catalog = Catalog(balance_resources=False, auto_reload=False) instance_type = instance_types(raw_instance_types[-1]) print(instance_type) @@ -323,7 +323,7 @@ def test_cpu_instance(raw_instance_types): def test_order(mocker, raw_instance_types): - catalog = Catalog(fill_missing=False, auto_reload=False) + catalog = Catalog(balance_resources=False, auto_reload=False) types = map(instance_types, raw_instance_types) diff --git a/src/tests/providers/test_tensordock.py b/src/tests/providers/test_tensordock.py index 01aacad..3eefe87 100644 --- a/src/tests/providers/test_tensordock.py +++ b/src/tests/providers/test_tensordock.py @@ -29,33 +29,39 @@ def specs() -> dict: class TestTensorDockMinimalConfiguration: def test_no_requirements(self, specs: dict): offers = TensorDockProvider.optimize_offers(QueryFilter(), specs, "", "") - assert offers == make_offers(specs, cpu=2, memory=2, disk_size=30, gpu_count=1) + assert offers == make_offers(specs, cpu=16, memory=96, disk_size=96, gpu_count=1) + + def test_min_cpu_no_balance(self, specs: dict): + offers = TensorDockProvider.optimize_offers( + QueryFilter(min_cpu=4), specs, "", "", balance_resources=False + ) + assert offers == make_offers(specs, cpu=4, memory=96, disk_size=96, gpu_count=1) def test_min_cpu(self, specs: dict): offers = TensorDockProvider.optimize_offers(QueryFilter(min_cpu=4), specs, "", "") - assert offers == make_offers(specs, cpu=4, memory=2, disk_size=30, gpu_count=1) + assert offers == make_offers(specs, cpu=16, memory=96, disk_size=96, gpu_count=1) def test_too_many_min_cpu(self, specs: dict): offers = TensorDockProvider.optimize_offers(QueryFilter(min_cpu=1000), specs, "", "") assert offers == [] + def test_min_memory_no_balance(self, specs: dict): + offers = TensorDockProvider.optimize_offers( + QueryFilter(min_memory=3), specs, "", "", balance_resources=False + ) + assert offers == make_offers(specs, cpu=2, memory=4, disk_size=48, gpu_count=1) + def test_min_memory(self, specs: dict): offers = TensorDockProvider.optimize_offers(QueryFilter(min_memory=3), specs, "", "") - assert offers == make_offers(specs, cpu=2, memory=4, disk_size=30, gpu_count=1) + assert offers == make_offers(specs, cpu=16, memory=96, disk_size=96, gpu_count=1) def test_too_large_min_memory(self, specs: dict): offers = TensorDockProvider.optimize_offers(QueryFilter(min_memory=2000), specs, "", "") assert offers == [] - def test_controversial_cpu(self, specs: dict): - offers = TensorDockProvider.optimize_offers( - QueryFilter(min_memory=8, max_memory=4), specs, "", "" - ) - assert offers == [] - def test_min_gpu_count(self, specs: dict): offers = TensorDockProvider.optimize_offers(QueryFilter(min_gpu_count=2), specs, "", "") - assert offers == make_offers(specs, cpu=2, memory=2, disk_size=30, gpu_count=2) + assert offers == make_offers(specs, cpu=32, memory=192, disk_size=192, gpu_count=2) def test_min_no_gpu(self, specs: dict): offers = TensorDockProvider.optimize_offers(QueryFilter(max_gpu_count=0), specs, "", "") @@ -65,7 +71,7 @@ def test_min_total_gpu_memory(self, specs: dict): offers = TensorDockProvider.optimize_offers( QueryFilter(min_total_gpu_memory=100), specs, "", "" ) - assert offers == make_offers(specs, cpu=2, memory=2, disk_size=30, gpu_count=3) + assert offers == make_offers(specs, cpu=48, memory=288, disk_size=288, gpu_count=3) def test_controversial_gpu(self, specs: dict): offers = TensorDockProvider.optimize_offers( @@ -77,7 +83,7 @@ def test_all_cpu_all_gpu(self, specs: dict): offers = TensorDockProvider.optimize_offers( QueryFilter(min_cpu=256, min_gpu_count=1), specs, "", "" ) - assert offers == make_offers(specs, cpu=256, memory=2, disk_size=30, gpu_count=8) + assert offers == make_offers(specs, cpu=256, memory=768, disk_size=768, gpu_count=8) def make_offers(