Skip to content

Commit

Permalink
Add Vultr Support
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihan Rana authored and Bihan Rana committed Dec 23, 2024
1 parent 0101ba0 commit d4abe00
Show file tree
Hide file tree
Showing 13 changed files with 474 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
BackendType.LAMBDA,
BackendType.OCI,
BackendType.TENSORDOCK,
BackendType.VULTR,
]
BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT = [
BackendType.AWS,
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/core/backends/vultr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dstack._internal.core.backends.base import Backend
from dstack._internal.core.backends.vultr.compute import VultrCompute
from dstack._internal.core.backends.vultr.config import VultrConfig
from dstack._internal.core.models.backends.base import BackendType


class VultrBackend(Backend):
TYPE: BackendType = BackendType.VULTR

def __init__(self, config: VultrConfig):
self.config = config
self._compute = VultrCompute(self.config)

def compute(self) -> VultrCompute:
return self._compute
154 changes: 154 additions & 0 deletions src/dstack/_internal/core/backends/vultr/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import base64
from typing import Any, List

import requests
from requests import Response

from dstack._internal.core.errors import BackendInvalidCredentialsError

API_URL = "https://api.vultr.com/v2"


class VultrApiClient:
def __init__(self, api_key: str):
self.api_key = api_key

def validate_api_key(self) -> bool:
try:
self._make_request("GET", "/ssh-keys")
except BackendInvalidCredentialsError:
return False
return True

def get_instance(self, instance_id: str, plan_type: str):
if plan_type == "bare-metal":
response = self._make_request("GET", f"/bare-metals/{instance_id}")
return response.json()["bare_metal"]
else:
response = self._make_request("GET", f"/instances/{instance_id}")
return response.json()["instance"]

def launch_instance(
self, region: str, plan: str, label: str, startup_script: str, public_keys: List[str]
):
# Fetch or create startup script ID
script_id: str = self.get_startup_script_id(startup_script)
# Fetch or create SSH key IDs
sshkey_ids: List[str] = self.get_sshkey_id(public_keys)
# For Bare-metals
if "vbm" in plan:
# "Docker on Ubuntu 22.04" is required for bare-metals.
data = {
"region": region,
"plan": plan,
"label": label,
"image_id": "docker",
"script_id": script_id,
"sshkey_id": sshkey_ids,
}
resp = self._make_request("POST", "/bare-metals", data)
return resp.json()["bare_metal"]["id"]
# For VMs
elif "vcg" in plan:
# Ubuntu 22.04 will be installed. For gpu VMs, docker is preinstalled.
data = {
"region": region,
"plan": plan,
"label": label,
"os_id": 1743,
"script_id": script_id,
"sshkey_id": sshkey_ids,
}
resp = self._make_request("POST", "/instances", data)
return resp.json()["instance"]["id"]
else:
data = {
"region": region,
"plan": plan,
"label": label,
"image_id": "docker",
"script_id": script_id,
"sshkey_id": sshkey_ids,
}
resp = self._make_request("POST", "/instances", data)
return resp.json()["instance"]["id"]

def get_startup_script_id(self, startup_script: str) -> str:
script_name = "dstack-shim-script"
encoded_script = base64.b64encode(startup_script.encode()).decode()

# Get the list of startup scripts
response = self._make_request("GET", "/startup-scripts")
scripts = response.json()["startup_scripts"]

# Find the script by name
existing_script = next((s for s in scripts if s["name"] == script_name), None)

if existing_script:
# Update the existing script
startup_id = existing_script["id"]
update_payload = {
"name": script_name,
"script": encoded_script,
}
self._make_request("PATCH", f"/startup-scripts/{startup_id}", update_payload)
else:
# Create a new script
create_payload = {
"name": script_name,
"type": "boot",
"script": encoded_script,
}
create_response = self._make_request("POST", "/startup-scripts", create_payload)
startup_id = create_response.json()["startup_script"]["id"]

return startup_id

def get_sshkey_id(self, ssh_ids: List[str]) -> List[str]:
# Fetch existing SSH keys
response = self._make_request("GET", "/ssh-keys")
ssh_keys = response.json()["ssh_keys"]

ssh_key_ids = []
existing_keys = {key["ssh_key"]: key["id"] for key in ssh_keys}

for ssh_key in ssh_ids:
if ssh_key in existing_keys:
# SSH key already exists, add its id to the list
ssh_key_ids.append(existing_keys[ssh_key])
else:
# Create new SSH key
create_payload = {"name": "dstack-ssh-key", "ssh_key": ssh_key}
create_response = self._make_request("POST", "/ssh-keys", create_payload)
new_ssh_key_id = create_response.json()["ssh_key"]["id"]
ssh_key_ids.append(new_ssh_key_id)

return ssh_key_ids

def terminate_instance(self, instance_id: str, plan_type: str):
if plan_type == "bare-metal":
# Terminate bare-metal instance
endpoint = f"/bare-metals/{instance_id}"
else:
# Terminate virtual machine instance
endpoint = f"/instances/{instance_id}"
self._make_request("DELETE", endpoint)

def _make_request(self, method: str, path: str, data: Any = None) -> Response:
try:
response = requests.request(
method=method,
url=API_URL + path,
json=data,
headers={"Authorization": f"Bearer {self.api_key}"},
timeout=30,
)
response.raise_for_status()
return response
except requests.HTTPError as e:
if e.response is not None and e.response.status_code in (
requests.codes.forbidden,
requests.codes.unauthorized,
):
raise BackendInvalidCredentialsError(e.response.text)
raise
128 changes: 128 additions & 0 deletions src/dstack/_internal/core/backends/vultr/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import json
from typing import List, Optional

import requests

from dstack._internal.core.backends.base import Compute
from dstack._internal.core.backends.base.compute import (
get_instance_name,
get_shim_commands,
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.vultr.api_client import VultrApiClient
from dstack._internal.core.backends.vultr.config import VultrConfig
from dstack._internal.core.errors import BackendError, ProvisioningError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
InstanceOfferWithAvailability,
SSHKey,
)
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import Volume
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


class VultrCompute(Compute):
def __init__(self, config: VultrConfig):
self.config = config
self.api_client = VultrApiClient(config.creds.api_key)

def get_offers(
self, requirements: Optional[Requirements] = None
) -> List[InstanceOfferWithAvailability]:
offers = get_catalog_offers(
backend=BackendType.VULTR,
requirements=requirements,
)
offers = [
InstanceOfferWithAvailability(
**offer.dict(), availability=InstanceAvailability.AVAILABLE
)
for offer in offers
]
return offers

def run_job(
self,
run: Run,
job: Job,
instance_offer: InstanceOfferWithAvailability,
project_ssh_public_key: str,
project_ssh_private_key: str,
volumes: List[Volume],
) -> JobProvisioningData:
instance_config = InstanceConfiguration(
project_name=run.project_name,
instance_name=get_instance_name(run, job),
ssh_keys=[SSHKey(public=project_ssh_public_key.strip())],
user=run.user,
)
return self.create_instance(instance_offer, instance_config)

def create_instance(
self, instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration
) -> JobProvisioningData:
public_keys = instance_config.get_public_keys()
commands = get_shim_commands(authorized_keys=public_keys)
shim_commands = "#!/bin/sh\n" + " ".join([" && ".join(commands)])
try:
instance_id = self.api_client.launch_instance(
region=instance_offer.region,
label=instance_config.instance_name,
plan=instance_offer.instance.name,
startup_script=shim_commands,
public_keys=public_keys,
)
except KeyError as e:
raise BackendError(e)

launched_instance = JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=instance_id,
hostname=None,
internal_ip=None,
region=instance_offer.region,
price=instance_offer.price,
ssh_port=22,
username="root",
ssh_proxy=None,
dockerized=True,
backend_data=json.dumps(
{
"plan_type": "bare-metal"
if "vbm" in instance_offer.instance.name
else "vm_instance"
}
),
)
return launched_instance

def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
) -> None:
plan_type = json.loads(backend_data)["plan_type"]
try:
self.api_client.terminate_instance(instance_id=instance_id, plan_type=plan_type)
except requests.HTTPError as e:
raise BackendError(e.response.text)

def update_provisioning_data(
self,
provisioning_data: JobProvisioningData,
project_ssh_public_key: str,
project_ssh_private_key: str,
):
plan_type = json.loads(provisioning_data.backend_data)["plan_type"]
instance_data = self.api_client.get_instance(provisioning_data.instance_id, plan_type)
# Access specific fields
instance_status = instance_data["status"]
instance_main_ip = instance_data["main_ip"]
if instance_status == "active":
provisioning_data.hostname = instance_main_ip
if instance_status == "failed":
raise ProvisioningError("VM entered FAILED state")
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/backends/vultr/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dstack._internal.core.backends.base.config import BackendConfig
from dstack._internal.core.models.backends.vultr import (
AnyVultrCreds,
VultrStoredConfig,
)


class VultrConfig(VultrStoredConfig, BackendConfig):
creds: AnyVultrCreds
10 changes: 10 additions & 0 deletions src/dstack/_internal/core/models/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
VastAIConfigInfoWithCredsPartial,
VastAIConfigValues,
)
from dstack._internal.core.models.backends.vultr import (
VultrConfigInfo,
VultrConfigInfoWithCreds,
VultrConfigInfoWithCredsPartial,
VultrConfigValues,
)
from dstack._internal.core.models.common import CoreModel

# The following models are the basis of the JSON-based backend API.
Expand All @@ -100,6 +106,7 @@
RunpodConfigInfo,
TensorDockConfigInfo,
VastAIConfigInfo,
VultrConfigInfo,
DstackConfigInfo,
DstackBaseBackendConfigInfo,
]
Expand All @@ -120,6 +127,7 @@
RunpodConfigInfoWithCreds,
TensorDockConfigInfoWithCreds,
VastAIConfigInfoWithCreds,
VultrConfigInfoWithCreds,
DstackConfigInfo,
]

Expand All @@ -141,6 +149,7 @@
RunpodConfigInfoWithCredsPartial,
TensorDockConfigInfoWithCredsPartial,
VastAIConfigInfoWithCredsPartial,
VultrConfigInfoWithCredsPartial,
DstackConfigInfo,
]

Expand All @@ -158,6 +167,7 @@
RunpodConfigValues,
TensorDockConfigValues,
VastAIConfigValues,
VultrConfigValues,
DstackConfigValues,
]

Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class BackendType(str, enum.Enum):
RUNPOD (BackendType): Runpod Cloud
TENSORDOCK (BackendType): TensorDock Marketplace
VASTAI (BackendType): Vast.ai Marketplace
VULTR (BackendType): Vultr
"""

AWS = "aws"
Expand All @@ -35,6 +36,7 @@ class BackendType(str, enum.Enum):
RUNPOD = "runpod"
TENSORDOCK = "tensordock"
VASTAI = "vastai"
VULTR = "vultr"


class ConfigElementValue(CoreModel):
Expand Down
Loading

0 comments on commit d4abe00

Please sign in to comment.