-
Notifications
You must be signed in to change notification settings - Fork 197
Add script to generate boilerplate code for new backend #2397
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5ce5f9c
0979ecc
84903f9
343ef7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import argparse | ||
from pathlib import Path | ||
|
||
import jinja2 | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser( | ||
description="This script generates boilerplate code for a new backend" | ||
) | ||
parser.add_argument( | ||
"-n", | ||
"--name", | ||
help=( | ||
"The backend name in CamelCase, e.g. AWS, Runpod, VastAI." | ||
" It'll be used for naming backend classes, models, etc." | ||
), | ||
required=True, | ||
) | ||
args = parser.parse_args() | ||
generate_backend_code(args.name) | ||
|
||
|
||
def generate_backend_code(backend_name: str): | ||
template_dir_path = Path(__file__).parent.parent.joinpath( | ||
"src/dstack/_internal/core/backends/template" | ||
) | ||
env = jinja2.Environment( | ||
loader=jinja2.FileSystemLoader( | ||
searchpath=template_dir_path, | ||
), | ||
keep_trailing_newline=True, | ||
) | ||
backend_dir_path = Path(__file__).parent.parent.joinpath( | ||
f"src/dstack/_internal/core/backends/{backend_name.lower()}" | ||
) | ||
backend_dir_path.mkdir(exist_ok=True) | ||
for filename in ["backend.py", "compute.py", "configurator.py", "models.py"]: | ||
template = env.get_template(f"{filename}.jinja") | ||
with open(backend_dir_path.joinpath(filename), "w+") as f: | ||
f.write(template.render({"backend_name": backend_name})) | ||
backend_dir_path.joinpath("__init__.py").write_text("") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from dstack._internal.core.backends.base.backend import Backend | ||
from dstack._internal.core.backends.{{ backend_name|lower }}.compute import {{ backend_name }}Compute | ||
from dstack._internal.core.backends.{{ backend_name|lower }}.models import {{ backend_name }}Config | ||
from dstack._internal.core.models.backends.base import BackendType | ||
|
||
|
||
class {{ backend_name }}Backend(Backend): | ||
TYPE = BackendType.{{ backend_name|upper }} | ||
COMPUTE_CLASS = {{ backend_name }}Compute | ||
|
||
def __init__(self, config: {{ backend_name }}Config): | ||
self.config = config | ||
self._compute = {{ backend_name }}Compute(self.config) | ||
|
||
def compute(self) -> {{ backend_name }}Compute: | ||
return self._compute |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from typing import List, Optional | ||
|
||
from dstack._internal.core.backends.base.backend import Compute | ||
from dstack._internal.core.backends.base.compute import ( | ||
ComputeWithCreateInstanceSupport, | ||
ComputeWithGatewaySupport, | ||
ComputeWithMultinodeSupport, | ||
ComputeWithPlacementGroupSupport, | ||
ComputeWithPrivateGatewaySupport, | ||
ComputeWithReservationSupport, | ||
ComputeWithVolumeSupport, | ||
) | ||
from dstack._internal.core.backends.base.offers import get_catalog_offers | ||
from dstack._internal.core.backends.{{ backend_name|lower }}.models import {{ backend_name }}Config | ||
from dstack._internal.core.models.backends.base import BackendType | ||
from dstack._internal.core.models.instances import ( | ||
InstanceAvailability, | ||
InstanceConfiguration, | ||
InstanceOfferWithAvailability, | ||
) | ||
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run | ||
from dstack._internal.core.models.volumes import Volume | ||
from dstack._internal.utils.logging import get_logger | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
class {{ backend_name }}Compute( | ||
# TODO: Choose ComputeWith* classes to extend and implement | ||
# ComputeWithCreateInstanceSupport, | ||
# ComputeWithMultinodeSupport, | ||
# ComputeWithReservationSupport, | ||
# ComputeWithPlacementGroupSupport, | ||
# ComputeWithGatewaySupport, | ||
# ComputeWithPrivateGatewaySupport, | ||
# ComputeWithVolumeSupport, | ||
Compute, | ||
): | ||
def __init__(self, config: {{ backend_name }}Config): | ||
super().__init__() | ||
self.config = config | ||
|
||
def get_offers( | ||
self, requirements: Optional[Requirements] = None | ||
) -> List[InstanceOfferWithAvailability]: | ||
# If the provider is added to gpuhunt, you'd typically get offers | ||
# using `get_catalog_offers()` and extend them with availability info. | ||
offers = get_catalog_offers( | ||
backend=BackendType.{{ backend_name|upper }}, | ||
locations=self.config.regions or None, | ||
requirements=requirements, | ||
# configurable_disk_size=..., TODO: set in case of boot volume size limits | ||
) | ||
# TODO: Add availability info to offers | ||
return [ | ||
InstanceOfferWithAvailability( | ||
**offer.dict(), | ||
availability=InstanceAvailability.UNKNOWN, | ||
) | ||
for offer in offers | ||
] | ||
|
||
def create_instance( | ||
self, | ||
instance_offer: InstanceOfferWithAvailability, | ||
instance_config: InstanceConfiguration, | ||
) -> JobProvisioningData: | ||
# TODO: Implement if backend supports creating instances (VM-based). | ||
# Delete if backend can only run jobs (container-based). | ||
raise NotImplementedError() | ||
|
||
def run_job( | ||
self, | ||
run: Run, | ||
job: Job, | ||
instance_offer: InstanceOfferWithAvailability, | ||
project_ssh_public_key: str, | ||
project_ssh_private_key: str, | ||
volumes: List[Volume], | ||
) -> JobProvisioningData: | ||
# TODO: Implement if create_instance() is not implemented. Delete otherwise. | ||
raise NotImplementedError() | ||
|
||
def terminate_instance( | ||
self, instance_id: str, region: str, backend_data: Optional[str] = None | ||
): | ||
raise NotImplementedError() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import json | ||
|
||
from dstack._internal.core.backends.base.configurator import ( | ||
BackendRecord, | ||
Configurator, | ||
raise_invalid_credentials_error, | ||
) | ||
from dstack._internal.core.backends.{{ backend_name|lower }}.backend import {{ backend_name }}Backend | ||
from dstack._internal.core.backends.{{ backend_name|lower }}.models import ( | ||
Any{{ backend_name }}BackendConfig, | ||
Any{{ backend_name }}Creds, | ||
{{ backend_name }}BackendConfig, | ||
{{ backend_name }}BackendConfigWithCreds, | ||
{{ backend_name }}Config, | ||
{{ backend_name }}Creds, | ||
{{ backend_name }}StoredConfig, | ||
) | ||
from dstack._internal.core.models.backends.base import ( | ||
BackendType, | ||
) | ||
|
||
# TODO: Add all supported regions and default regions | ||
REGIONS = [] | ||
Comment on lines
+22
to
+23
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hardcoded regions are not needed for most backends, so I wouldn't include this in the template. I guess they used to be needed for interactive setup, but now they are only needed for backends with custom-built VM images that are not available in all regions. For other backends, hardcoded regions are rather harmful, as they prevent users from using newly added regions There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Regions don't need to be hardcoded but they need to be validated, and hardcoding them seems to be the best option for most GPU clouds (have a few regions, don't add new regions often, unlikely to have an API to get all the regions). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Anyway, feel free to adjust the template on regions. |
||
|
||
|
||
class {{ backend_name }}Configurator(Configurator): | ||
TYPE = BackendType.{{ backend_name|upper }} | ||
BACKEND_CLASS = {{ backend_name }}Backend | ||
|
||
def validate_config( | ||
self, config: {{ backend_name }}BackendConfigWithCreds, default_creds_enabled: bool | ||
): | ||
self._validate_creds(config.creds) | ||
# TODO: Validate additional config parameters if any | ||
|
||
def create_backend( | ||
self, project_name: str, config: {{ backend_name }}BackendConfigWithCreds | ||
) -> BackendRecord: | ||
if config.regions is None: | ||
config.regions = REGIONS | ||
return BackendRecord( | ||
config={{ backend_name }}StoredConfig( | ||
**{{ backend_name }}BackendConfig.__response__.parse_obj(config).dict() | ||
).json(), | ||
auth={{ backend_name }}Creds.parse_obj(config.creds).json(), | ||
) | ||
|
||
def get_backend_config( | ||
self, record: BackendRecord, include_creds: bool | ||
) -> Any{{ backend_name }}BackendConfig: | ||
config = self._get_config(record) | ||
if include_creds: | ||
return {{ backend_name }}BackendConfigWithCreds.__response__.parse_obj(config) | ||
return {{ backend_name }}BackendConfig.__response__.parse_obj(config) | ||
|
||
def get_backend(self, record: BackendRecord) -> {{ backend_name }}Backend: | ||
config = self._get_config(record) | ||
return {{ backend_name }}Backend(config=config) | ||
|
||
def _get_config(self, record: BackendRecord) -> {{ backend_name }}Config: | ||
return {{ backend_name }}Config.__response__( | ||
**json.loads(record.config), | ||
creds={{ backend_name }}Creds.parse_raw(record.auth), | ||
) | ||
|
||
def _validate_creds(self, creds: Any{{ backend_name }}Creds): | ||
# TODO: Implement API key or other creds validation | ||
# if valid: | ||
# return | ||
raise_invalid_credentials_error(fields=[["creds", "api_key"]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit) Consider adding a sample call to
get_catalog_offers
, as contributors can forget to pass important arguments, such aslocations
orconfigurable_disk_size
.