Skip to content

Commit

Permalink
Add Nebius provider (#7)
Browse files Browse the repository at this point in the history
* Implement Nebius API Client & Provider

* Collect Nebius catalog

* Configure catalog channel

* Change default channel for manual workflow trigger

* Update readme and catalog. Fix nebius spots pricing
  • Loading branch information
Egor-S authored Nov 7, 2023
1 parent 785e71a commit 5cb42c2
Show file tree
Hide file tree
Showing 8 changed files with 337 additions and 6 deletions.
43 changes: 40 additions & 3 deletions .github/workflows/catalogs.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: Collect and publish catalogs
on:
workflow_dispatch:
inputs:
channel:
description: 'Channel to publish catalogs to'
required: true
default: stgn
schedule:
- cron: '0 1 * * *' # 01:00 UTC every day

Expand Down Expand Up @@ -115,13 +120,37 @@ jobs:
path: lambdalabs.csv
retention-days: 1

# Job: collect the Nebius catalog into nebius.csv and upload it as an artifact.
# The `test-catalog` job lists this job in its `needs`.
catalog-nebius:
name: Collect Nebius catalog
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install dependencies
run: |
pip install pip -U
# The `nebius` extra adds pyjwt, cryptography and beautifulsoup4 (see pyproject.toml).
pip install -e '.[nebius]'
- name: Collect catalog
working-directory: src
env:
# Service account key as a JSON document; `gpuhunt.__main__` parses it with json.loads.
NEBIUS_SERVICE_ACCOUNT: ${{ secrets.NEBIUS_SERVICE_ACCOUNT }}
# Runs from `src`, so `../nebius.csv` lands in the repository root.
run: python -m gpuhunt nebius --output ../nebius.csv
- uses: actions/upload-artifact@v3
with:
name: catalogs
path: nebius.csv
retention-days: 1

test-catalog:
name: Test catalogs integrity
needs:
- catalog-aws
- catalog-azure
- catalog-gcp
- catalog-lambdalabs
- catalog-nebius
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
Expand Down Expand Up @@ -154,12 +183,20 @@ jobs:
run: date +%Y%m%d > version
- name: Package catalogs
run: zip catalog.zip *.csv version
- name: Set channel
run: |
if [[ ${{ github.event_name == 'workflow_dispatch' }} == true ]]; then
CHANNEL=${{ inputs.channel }}
else
CHANNEL=${{ vars.CHANNEL }}
fi
echo "CHANNEL=$CHANNEL" >> $GITHUB_ENV
- name: Upload to S3
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
VERSION=$(cat version)
aws s3 cp catalog.zip "s3://dstack-gpu-pricing/v1/$VERSION/catalog.zip" --acl public-read
cat version | aws s3 cp - "s3://dstack-gpu-pricing/v1/version" --acl public-read
aws s3 cp "s3://dstack-gpu-pricing/v1/$VERSION/catalog.zip" "s3://dstack-gpu-pricing/v1/latest/catalog.zip" --acl public-read
aws s3 cp catalog.zip "s3://dstack-gpu-pricing/$CHANNEL/$VERSION/catalog.zip" --acl public-read
cat version | aws s3 cp - "s3://dstack-gpu-pricing/$CHANNEL/version" --acl public-read
aws s3 cp "s3://dstack-gpu-pricing/$CHANNEL/$VERSION/catalog.zip" "s3://dstack-gpu-pricing/$CHANNEL/latest/catalog.zip" --acl public-read
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ print(*items, sep="\n")
* Azure
* GCP
* LambdaLabs
* Nebius
* TensorDock
* Vast AI

Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ gcp = [
"google-cloud-billing",
"google-cloud-compute"
]
all = ["gpuhunt[aws,azure,gcp]"]
nebius = [
"pyjwt",
"cryptography",
"beautifulsoup4"
]
all = ["gpuhunt[aws,azure,gcp,nebius]"]

[tool.setuptools.dynamic]
version = {attr = "gpuhunt.version.__version__"}
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ google-cloud-billing
google-cloud-compute
azure-mgmt-compute
azure-identity
beautifulsoup4
pyjwt
cryptography
7 changes: 6 additions & 1 deletion src/gpuhunt/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import json
import logging
import os
import sys
Expand All @@ -10,7 +11,7 @@ def main():
parser = argparse.ArgumentParser(prog="python3 -m gpuhunt")
parser.add_argument(
"provider",
choices=["aws", "azure", "gcp", "lambdalabs", "tensordock", "vastai"],
choices=["aws", "azure", "gcp", "lambdalabs", "nebius", "tensordock", "vastai"],
)
parser.add_argument("--output", required=True)
parser.add_argument("--no-filter", action="store_true")
Expand All @@ -37,6 +38,10 @@ def main():
from gpuhunt.providers.lambdalabs import LambdaLabsProvider

provider = LambdaLabsProvider(os.getenv("LAMBDALABS_TOKEN"))
elif args.provider == "nebius":
from gpuhunt.providers.nebius import NebiusProvider

provider = NebiusProvider(json.loads(os.getenv("NEBIUS_SERVICE_ACCOUNT")))
elif args.provider == "tensordock":
from gpuhunt.providers.tensordock import TensorDockProvider

Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/_internal/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logger = logging.getLogger(__name__)
version_url = "https://dstack-gpu-pricing.s3.eu-west-1.amazonaws.com/v1/version"
catalog_url = "https://dstack-gpu-pricing.s3.eu-west-1.amazonaws.com/v1/{version}/catalog.zip"
OFFLINE_PROVIDERS = ["aws", "azure", "gcp", "lambdalabs"]
OFFLINE_PROVIDERS = ["aws", "azure", "gcp", "lambdalabs", "nebius"]
ONLINE_PROVIDERS = ["tensordock", "vastai"]
RELOAD_INTERVAL = 4 * 60 * 60 # 4 hours

Expand Down
247 changes: 247 additions & 0 deletions src/gpuhunt/providers/nebius.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
import datetime
import logging
import re
import time
from collections import defaultdict
from typing import Dict, List, Literal, Optional, TypedDict

import bs4
import jwt
import requests

from gpuhunt import QueryFilter, RawCatalogItem
from gpuhunt.providers import AbstractProvider

logger = logging.getLogger(__name__)
# Base domain of the Nebius API; NebiusAPIClient.url() prefixes it with the
# service name to build endpoint URLs.
API_URL = "api.ai.nebius.cloud"
# Billing service id used in the `serviceId="..."` filter when listing SKUs
# (presumably the Compute service, judging by the name; confirm against the
# billing API).
COMPUTE_SERVICE_ID = "bfa2pas77ftg9h3f2djj"
# Maps the GPU name captured from a SKU name ("... with Nvidia <name>") to the
# platform id under which prices are stored and later looked up for the
# platforms scraped from the documentation. The None key covers SKUs whose
# name has no "with Nvidia ..." part.
GPU_NAME_PLATFORM = {
    "A100": "gpu-standard-v3",
    "H100 PCIe": "standard-v3-h100-pcie",
    "Hopper H100 SXM (Type A)": "gpu-h100",
    "Hopper H100 SXM (Type B)": "gpu-h100-b",
    "L4": "standard-v3-l4",
    "L40": "standard-v3-l40",
    None: "standard-v2",
}


class NebiusProvider(AbstractProvider):
    """Build Nebius catalog items from billing SKUs and the public documentation.

    Unit prices per platform (vCPU, RAM, GPU) are read from the billing API,
    while the instance configurations offered on each platform are scraped
    from the Nebius documentation pages. Every item's price is the sum of its
    resources multiplied by the matching unit prices.
    """

    # Seconds to wait for a documentation page, so a stalled connection cannot
    # hang the collection run indefinitely.
    _DOCS_TIMEOUT = 30
    # Deletion table removing the newlines and spaces inside scraped table cells.
    _CELL_WHITESPACE = str.maketrans("", "", "\n ")
    # GPU model token (uppercase letters followed by digits, e.g. "H100")
    # surrounded by spaces in a platform description.
    _GPU_NAME_RE = re.compile(r" ([A-Z]+\d+) ")

    def __init__(self, service_account: "ServiceAccount"):
        """Create the provider.

        Args:
            service_account: service account key used to authenticate API calls.
        """
        self.api_client = NebiusAPIClient(service_account)

    def get(self, query_filter: Optional[QueryFilter] = None) -> List[RawCatalogItem]:
        """Return all GPU and CPU catalog items.

        Args:
            query_filter: not used here; the full catalog is always returned.

        Returns:
            GPU platform items followed by CPU platform items. Every item uses
            the id of the first zone returned by the compute API as location.
        """
        zone = self.api_client.compute_zones_list()[0]["id"]
        skus = []
        page_token = None
        logger.info("Fetching SKUs")
        while True:
            page = self.api_client.billing_skus_list(
                filter=f'serviceId="{COMPUTE_SERVICE_ID}"', page_token=page_token
            )
            # Tolerate a page that omits the list entirely.
            skus += page.get("skus", [])
            page_token = page.get("nextPageToken")
            # Stop on a missing *or empty* token; looping on "" would never end.
            if not page_token:
                break
        platform_resources = self.aggregate_skus(skus)
        return self.get_gpu_platforms(zone, platform_resources) + self.get_cpu_platforms(
            zone, platform_resources
        )

    def get_gpu_platforms(
        self, zone: str, platform_resources: "PlatformResourcePrice"
    ) -> List[RawCatalogItem]:
        """Scrape GPU configurations from the docs and price them.

        Args:
            zone: location assigned to every produced item.
            platform_resources: unit prices keyed by platform id, then by
                "cpu" / "ram" / "gpu".

        Returns:
            One on-demand item per table row of every documented GPU platform.

        Raises:
            requests.HTTPError: if the documentation page cannot be fetched.
            KeyError: if a documented platform lacks a required unit price.
        """
        logger.info("Fetching GPU platforms")
        resp = requests.get(
            "https://nebius.ai/docs/compute/concepts/gpus", timeout=self._DOCS_TIMEOUT
        )
        resp.raise_for_status()
        soup = bs4.BeautifulSoup(resp.text, "html.parser")
        configs = soup.find("h2", id="config").find_next_sibling("ul").find_all("li")
        items = []
        for li in configs:
            description = li.find("p")
            platform = description.find("code").text
            prices = platform_resources[platform]
            gpu_name = self._GPU_NAME_RE.search(description.text).group(1)
            for tr in li.find("tbody").find_all("tr"):
                # Columns as used below: GPU count, total GPU memory, vCPUs, RAM
                # (units are whatever the docs table uses; presumably GB).
                tds = tr.find_all("td")
                gpu_count = int(tds[0].text.strip(" *"))
                cpu = int(tds[2].text)
                memory = float(tds[3].text)
                items.append(
                    RawCatalogItem(
                        instance_name=platform,
                        location=zone,
                        price=round(
                            cpu * prices["cpu"]
                            + memory * prices["ram"]
                            + gpu_count * prices["gpu"],
                            5,
                        ),
                        cpu=cpu,
                        memory=memory,
                        gpu_count=gpu_count,
                        gpu_name=gpu_name,
                        # The table lists memory for the whole configuration.
                        gpu_memory=int(tds[1].text) / gpu_count,
                        spot=False,
                    )
                )
        return items

    def get_cpu_platforms(
        self, zone: str, platform_resources: "PlatformResourcePrice"
    ) -> List[RawCatalogItem]:
        """Scrape CPU-only configurations from the docs and price them.

        Only the table row whose first cell is "100%" is used. That row lists
        the allowed vCPU counts and the allowed RAM-per-vCPU ratios; one item
        is produced for every (ratio, vCPU count) combination.

        Args:
            zone: location assigned to every produced item.
            platform_resources: unit prices keyed by platform id, then by
                "cpu" / "ram".

        Returns:
            On-demand items without GPUs.

        Raises:
            requests.HTTPError: if the documentation page cannot be fetched.
            KeyError: if a documented platform lacks a required unit price.
        """
        logger.info("Fetching CPU platforms")
        resp = requests.get(
            "https://nebius.ai/docs/compute/concepts/performance-levels",
            timeout=self._DOCS_TIMEOUT,
        )
        resp.raise_for_status()
        soup = bs4.BeautifulSoup(resp.text, "html.parser")
        configs = (
            soup.find(
                "p",
                string=re.compile(
                    r"The computing resources may have the following configurations:"
                ),
            )
            .find_next_sibling("ul")
            .find_all("li")
        )
        items = []
        for li in configs:
            platform = li.find("p").find("code").text
            prices = platform_resources[platform]
            tds = li.find("tbody").find("td", string="100%").find_next_siblings("td")
            cpus = self._parse_cell_list(tds[0].text, int)
            ratios = self._parse_cell_list(tds[1].text, float)
            for ratio in ratios:
                for cpu in cpus:
                    memory = cpu * ratio
                    items.append(
                        RawCatalogItem(
                            instance_name=platform,
                            location=zone,
                            price=round(cpu * prices["cpu"] + memory * prices["ram"], 5),
                            cpu=cpu,
                            memory=memory,
                            gpu_count=0,
                            gpu_name=None,
                            gpu_memory=None,
                            spot=False,
                        )
                    )
        return items

    def aggregate_skus(self, skus: List[dict]) -> "PlatformResourcePrice":
        """Collect on-demand unit prices per platform from billing SKUs.

        SKU names are expected to look like
        "<Intel|AMD ...>[ with Nvidia <gpu>]. <GPU|RAM|100% vCPU>[ — preemptible...]".
        SKUs that do not match, preemptible SKUs, SKUs for unknown GPU names
        and SKUs without a currently effective price are skipped.

        Args:
            skus: SKU dicts with at least "name" and "pricingVersions".

        Returns:
            Mapping of platform id to {"cpu" | "ram" | "gpu": unit price}.
        """
        vm_resources = {
            "GPU": "gpu",
            "RAM": "ram",
            "100% vCPU": "cpu",
        }
        # Escape the names so they are always matched literally.
        resource_alternatives = "|".join(re.escape(name) for name in vm_resources)
        vm_name_re = re.compile(
            r"((?:Intel|AMD) .+?)(?: with Nvidia (.+))?"
            rf"\. ({resource_alternatives})(?: — (preemptible).*)?$"
        )
        platform_resources = defaultdict(dict)
        for sku in skus:
            if (r := vm_name_re.match(sku["name"])) is None:
                continue  # storage, images, snapshots, infiniband
            _cpu_name, gpu_name, resource_name, spot = r.groups()
            if spot is not None:
                continue
            if gpu_name not in GPU_NAME_PLATFORM:
                logger.warning("Unknown GPU name: %s", gpu_name)
                continue
            price = self.get_sku_price(sku["pricingVersions"])
            if price is None:
                # Storing None would only fail later with an opaque TypeError
                # and could overwrite a valid price from another SKU.
                logger.warning("No effective price for SKU: %s", sku["name"])
                continue
            platform_resources[GPU_NAME_PLATFORM[gpu_name]][vm_resources[resource_name]] = price

        return platform_resources

    def get_sku_price(self, pricing_versions: List[dict]) -> Optional[float]:
        """Return the currently effective street price of a SKU.

        Among the "STREET_PRICE" versions whose effective time is not in the
        future, the most recent one wins (the later list entry on a tie).

        Args:
            pricing_versions: dicts with "type", "effectiveTime" and
                "pricingExpressions".

        Returns:
            The unit price, or None if no street price is effective yet.
        """
        now = datetime.datetime.now(datetime.timezone.utc)
        latest = None
        latest_time = None
        for version in pricing_versions:
            # I guess it's the price for on-demand instances
            if version["type"] != "STREET_PRICE":
                continue
            effective = self._parse_effective_time(version["effectiveTime"])
            if effective > now:
                continue
            if latest_time is None or effective >= latest_time:
                latest, latest_time = version, effective
        if latest is None:
            return None
        # I guess we should take the first pricing expression
        return float(latest["pricingExpressions"][0]["rates"][0]["unitPrice"])

    @classmethod
    def _parse_cell_list(cls, text: str, cast) -> list:
        """Split a comma-separated table cell and convert each part with `cast`."""
        return [cast(part) for part in text.translate(cls._CELL_WHITESPACE).split(",")]

    @staticmethod
    def _parse_effective_time(value: str) -> datetime.datetime:
        """Parse an ISO 8601 timestamp into a timezone-aware datetime."""
        # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11 on.
        if value.endswith(("Z", "z")):
            value = value[:-1] + "+00:00"
        parsed = datetime.datetime.fromisoformat(value)
        if parsed.tzinfo is None:
            # NOTE(review): assumes offset-less timestamps are UTC; confirm with API.
            parsed = parsed.replace(tzinfo=datetime.timezone.utc)
        return parsed


class NebiusAPIClient:
    """Small HTTP client for the Nebius IAM, billing and compute APIs.

    reference: https://nebius.ai/docs/api-design-guide/
    """

    # (connect, read) timeouts in seconds for every HTTP call; without them a
    # stalled connection would block the caller forever.
    _TIMEOUT = (10, 60)

    def __init__(self, service_account: "ServiceAccount"):
        """Store the credentials; no network traffic happens here.

        Args:
            service_account: key whose "id", "service_account_id" and
                "private_key" fields are used to obtain IAM tokens.
        """
        self._service_account = service_account
        self._s = requests.Session()
        # Unix time until which the current token is treated as valid;
        # 0 forces a fetch on first use.
        self._expires_at = 0

    def get_token(self) -> None:
        """Make sure the session carries a usable IAM token.

        Does nothing while the cached token has more than 60 seconds left.
        Otherwise signs a one-hour JWT (PS256) with the service account's
        private key, exchanges it at the IAM tokens endpoint and stores the
        result in the session's Authorization header.

        Raises:
            requests.HTTPError: if the IAM endpoint rejects the request.
        """
        now = int(time.time())
        if now + 60 < self._expires_at:
            return
        logger.debug("Refreshing IAM token")
        expires_at = now + 3600
        payload = {
            "aud": self.url("iam", "/tokens"),
            "iss": self._service_account["service_account_id"],
            "iat": now,
            "exp": expires_at,
        }
        jwt_token = jwt.encode(
            payload,
            self._service_account["private_key"],
            algorithm="PS256",
            headers={"kid": self._service_account["id"]},
        )

        # Sent with a bare requests.post, so the session's current
        # Authorization header is not attached to the exchange.
        resp = requests.post(payload["aud"], json={"jwt": jwt_token}, timeout=self._TIMEOUT)
        resp.raise_for_status()
        iam_token = resp.json()["iamToken"]
        self._s.headers["Authorization"] = f"Bearer {iam_token}"
        # NOTE(review): treats the IAM token as valid for as long as the JWT;
        # confirm the real token lifetime is at least one hour.
        self._expires_at = expires_at

    def billing_skus_list(
        self,
        filter: Optional[str] = None,
        page_size: Optional[int] = 1000,
        page_token: Optional[str] = None,
    ) -> "BillingSkusListResponse":
        """Fetch one page of billing SKUs priced in USD.

        Args:
            filter: optional filter expression sent as the "filter" parameter.
            page_size: value of the "pageSize" parameter.
            page_token: token of the page to fetch; None requests the first page.

        Returns:
            The decoded JSON response.

        Raises:
            requests.HTTPError: on a non-success HTTP status.
        """
        logger.debug("Fetching SKUs")
        params = {
            "currency": "USD",
            "pageSize": page_size,
        }
        if filter is not None:
            params["filter"] = filter
        if page_token is not None:
            params["pageToken"] = page_token
        self.get_token()
        resp = self._s.get(self.url("billing", "/skus"), params=params, timeout=self._TIMEOUT)
        resp.raise_for_status()
        return resp.json()

    def compute_zones_list(self) -> List[dict]:
        """Return the "zones" list from the compute API.

        Raises:
            requests.HTTPError: on a non-success HTTP status.
        """
        logger.debug("Fetching compute zones")
        self.get_token()
        resp = self._s.get(self.url("compute", "/zones"), timeout=self._TIMEOUT)
        resp.raise_for_status()
        return resp.json()["zones"]

    def url(self, service: str, path: str, version="v1") -> str:
        """Build an endpoint URL.

        Example: url("iam", "/tokens") gives
        "https://iam.<API_URL>/iam/v1/tokens".
        """
        return f"https://{service}.{API_URL.rstrip('/')}/{service}/{version}/{path.lstrip('/')}"


class ServiceAccount(TypedDict):
    """Service account key document used to authenticate against the Nebius API.

    NebiusAPIClient reads only `id`, `service_account_id` and `private_key`;
    the remaining fields are declared but not read in this module.
    """

    # Key id; sent as the "kid" header of the signed JWT.
    id: str
    # Service account id; sent as the "iss" claim of the JWT.
    service_account_id: str
    created_at: str
    key_algorithm: str
    public_key: str
    # Private key used to sign the JWT with PS256.
    private_key: str


class BillingSkusListResponse(TypedDict):
    """One page of the billing SKU listing, as decoded from JSON."""

    # SKU dicts; this module reads their "name" and "pricingVersions" keys.
    skus: List[dict]
    # Token for the next page; NebiusProvider.get reads it with .get() and
    # stops paginating when it is None.
    nextPageToken: Optional[str]


# Unit prices keyed by platform id, then by resource kind ("cpu", "ram" or
# "gpu"), as built by NebiusProvider.aggregate_skus from get_sku_price results.
PlatformResourcePrice = Dict[str, Dict[Literal["cpu", "ram", "gpu"], float]]
Loading

0 comments on commit 5cb42c2

Please sign in to comment.