fixes #780 (#781)
add instance type for aws_batch_scheduler multinode jobs

Co-authored-by: Alexander Jipa <[email protected]>
Alexander Jipa and azzhipa authored Nov 3, 2023
1 parent fe6354d commit b1e56b2
Showing 3 changed files with 54 additions and 3 deletions.
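
For context before the diff, a hedged sketch (not part of this commit; the module name `my_app.train` is hypothetical) of how a role opts into this behavior: using a named AWS resource attaches the instance type as a `K8S_ITYPE` capability, which the change below copies into the AWS Batch container spec for multi-node jobs.

# Hedged usage sketch, assuming a working torchx install.
from torchx import specs
from torchx.specs.named_resources_aws import aws_p3dn_24xlarge

trainer = specs.Role(
    name="trainer",
    image="pytorch/torchx:latest",   # same image the tests below use
    entrypoint="python",
    args=["-m", "my_app.train"],     # hypothetical training module
    resource=aws_p3dn_24xlarge(),    # capabilities include K8S_ITYPE="p3dn.24xlarge"
    num_replicas=2,                  # >1 makes this a multi-node Batch job
)
app = specs.AppDef(name="my_app", roles=[trainer])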
5 changes: 5 additions & 0 deletions torchx/schedulers/aws_batch_scheduler.py
@@ -81,6 +81,7 @@
runopts,
VolumeMount,
)
from torchx.specs.named_resources_aws import instance_type_from_resource
from torchx.util.types import none_throws
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
from typing_extensions import TypedDict
@@ -244,6 +245,10 @@ def _role_to_node_properties(
"mountPoints": mount_points,
"volumes": volumes,
}
if role.num_replicas > 1:
instance_type = instance_type_from_resource(role.resource)
if instance_type is not None:
container["instanceType"] = instance_type

return {
"targetNodes": f"{start_idx}:{start_idx + role.num_replicas - 1}",
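A hedged illustration of the node-range entry the new branch produces for such a role; only fields visible in this hunk and its tests are shown, and the remaining container fields are elided.

# Hypothetical, trimmed nodeRangeProperties entry for start_idx=0,
# num_replicas=2, and an aws_p3dn_24xlarge resource.
node_range_properties = [
    {
        "targetNodes": "0:1",  # f"{start_idx}:{start_idx + num_replicas - 1}"
        "container": {
            "mountPoints": [],                # elided
            "volumes": [],                    # elided
            "instanceType": "p3dn.24xlarge",  # new: set only when num_replicas > 1
        },
    }
]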
39 changes: 36 additions & 3 deletions torchx/schedulers/test/aws_batch_scheduler_test.py
@@ -28,7 +28,9 @@
from torchx.specs import AppState, Resource


def _test_app() -> specs.AppDef:
def _test_app(
num_replicas: int = 2, resource: Optional[Resource] = None
) -> specs.AppDef:
trainer_role = specs.Role(
name="trainer",
image="pytorch/torchx:latest",
@@ -41,13 +43,14 @@ def _test_app() -> specs.AppDef:
f" --rank0_host $${{{specs.macros.rank0_env}:=localhost}}",
],
env={"FOO": "bar"},
resource=specs.Resource(
resource=resource
or specs.Resource(
cpu=2,
memMB=3000,
gpu=4,
),
port_map={"foo": 1234},
num_replicas=2,
num_replicas=num_replicas,
max_retries=3,
mounts=[
specs.BindMount(src_path="/src", dst_path="/dst", read_only=True),
@@ -156,6 +159,36 @@ def test_submit_dryrun_privileged(self) -> None:
self.assertEqual(1, len(node_groups))
self.assertTrue(node_groups[0]["container"]["privileged"])

def test_submit_dryrun_instance_type_multinode(self) -> None:
cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
app = _test_app(num_replicas=2, resource=resource)
info = create_scheduler("test").submit_dryrun(app, cfg)
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertEqual(
resource.capabilities[specs.named_resources_aws.K8S_ITYPE],
node_groups[0]["container"]["instanceType"],
)

def test_submit_dryrun_no_instance_type_singlenode(self) -> None:
cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
app = _test_app(num_replicas=1, resource=resource)
info = create_scheduler("test").submit_dryrun(app, cfg)
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertTrue("instanceType" not in node_groups[0]["container"])

def test_submit_dryrun_no_instance_type_non_aws(self) -> None:
cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})
resource = specs.named_resources_aws.aws_p3dn_24xlarge()
app = _test_app(num_replicas=2)
info = create_scheduler("test").submit_dryrun(app, cfg)
node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"]
self.assertEqual(1, len(node_groups))
self.assertTrue("instanceType" not in node_groups[0]["container"])

@mock_rand()
def test_submit_dryrun(self) -> None:
cfg = AWSBatchOpts({"queue": "testqueue", "user": "testuser"})
13 changes: 13 additions & 0 deletions torchx/specs/named_resources_aws.py
@@ -29,6 +29,7 @@
"""

import warnings
from typing import Callable, Mapping, Optional

from torchx.specs.api import Resource
@@ -41,10 +42,22 @@
# 97% is based on empirical observation that works well for most instance types
# see: https://docs.aws.amazon.com/batch/latest/userguide/memory-management.html
MEM_TAX = 0.97

# determines instance type for non-homogeneous CEs
# see https://github.com/pytorch/torchx/issues/780
K8S_ITYPE = "node.kubernetes.io/instance-type"
GiB: int = int(1024 * MEM_TAX)


def instance_type_from_resource(resource: Resource) -> Optional[str]:
instance_type = resource.capabilities.get(K8S_ITYPE)
if instance_type is None:
warnings.warn(
"Cannot determine resource instance type which can cause issues for non-homogeneous CEs and multinode jobs. Consider providing torchx.specs.named_resources_aws:K8S_TYPE resource capability."
)
return instance_type


def aws_p3_2xlarge() -> Resource:
return Resource(
cpu=8, gpu=1, memMB=61 * GiB, capabilities={K8S_ITYPE: "p3.2xlarge"}
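A hedged usage sketch of the new helper (the generic `Resource` below is illustrative): named AWS resources resolve to their instance type, while resources without the `K8S_ITYPE` capability emit the warning and return `None`.

# Sketch of the helper's behavior under the assumptions above.
from torchx.specs.api import Resource
from torchx.specs.named_resources_aws import (
    aws_p3dn_24xlarge,
    instance_type_from_resource,
)

print(instance_type_from_resource(aws_p3dn_24xlarge()))  # "p3dn.24xlarge"

generic = Resource(cpu=2, gpu=0, memMB=3000)             # no K8S_ITYPE capability
print(instance_type_from_resource(generic))              # warns, then prints None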
