From 407cf84bf83d45f13ff0774a9257b12a49f53c1b Mon Sep 17 00:00:00 2001 From: Ryan Li Date: Fri, 14 Jun 2024 01:24:28 +0000 Subject: [PATCH] support mounting neuron devices for local_docker scheduler --- torchx/schedulers/devices.py | 15 +++++++++++---- .../schedulers/test/aws_batch_scheduler_test.py | 10 +++++++++- torchx/schedulers/test/devices_test.py | 3 ++- torchx/schedulers/test/docker_scheduler_test.py | 11 +++++++++-- torchx/specs/named_resources_aws.py | 9 +++++++-- torchx/specs/test/named_resources_aws_test.py | 3 +++ 6 files changed, 41 insertions(+), 10 deletions(-) diff --git a/torchx/schedulers/devices.py b/torchx/schedulers/devices.py index 2491c2c68..8656e69e9 100644 --- a/torchx/schedulers/devices.py +++ b/torchx/schedulers/devices.py @@ -7,25 +7,32 @@ # pyre-strict import warnings +from functools import partial from typing import Callable, Dict, List, Mapping from torchx.specs.api import DeviceMount +from torchx.specs.named_resources_aws import EFA_DEVICE, NEURON_DEVICE -def efa_to_devicemounts(num_devices: int) -> List[DeviceMount]: +def to_devicemounts(num_devices: int, device_type: str) -> List[DeviceMount]: device_mounts = [] for device_index in range(0, num_devices): device_mounts.append( DeviceMount( - src_path="/dev/infiniband/uverbs" + str(device_index), - dst_path="/dev/infiniband/uverbs" + str(device_index), + src_path=device_type + str(device_index), + dst_path=device_type + str(device_index), ) ) return device_mounts +neuron_to_devicemounts: Callable[[int], List[DeviceMount]] = partial(to_devicemounts, device_type="/dev/neuron") +efa_to_devicemounts: Callable[[int], List[DeviceMount]] = partial(to_devicemounts, device_type="/dev/infiniband/uverbs") + + DEVICES: Mapping[str, Callable[[int], List[DeviceMount]]] = { - "vpc.amazonaws.com/efa": efa_to_devicemounts, + EFA_DEVICE: efa_to_devicemounts, + NEURON_DEVICE: neuron_to_devicemounts, } diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py index d9a2914aa..482116eb8 100644 --- a/torchx/schedulers/test/aws_batch_scheduler_test.py +++ b/torchx/schedulers/test/aws_batch_scheduler_test.py @@ -361,7 +361,10 @@ def test_resource_devices(self) -> None: image="", mounts=[], resource=specs.Resource( - cpu=1, memMB=1000, gpu=0, devices={"vpc.amazonaws.com/efa": 2} + cpu=1, + memMB=1000, + gpu=0, + devices={"vpc.amazonaws.com/efa": 2, "aws.amazon.com/neurondevice": 1}, ), ) props = _role_to_node_properties(role, 0) @@ -379,6 +382,11 @@ def test_resource_devices(self) -> None: "containerPath": "/dev/infiniband/uverbs1", "permissions": ["READ", "WRITE", "MKNOD"], }, + { + "hostPath": "/dev/neuron0", + "containerPath": "/dev/neuron0", + "permissions": ["READ", "WRITE", "MKNOD"], + }, ], ) diff --git a/torchx/schedulers/test/devices_test.py b/torchx/schedulers/test/devices_test.py index f6099b83f..c86d58e1a 100644 --- a/torchx/schedulers/test/devices_test.py +++ b/torchx/schedulers/test/devices_test.py @@ -16,7 +16,7 @@ class DevicesTest(unittest.TestCase): def test_get_efa(self) -> None: - devices = {"vpc.amazonaws.com/efa": 2} + devices = {"vpc.amazonaws.com/efa": 2, "aws.amazon.com/neurondevice": 1} self.assertEqual( get_device_mounts(devices), [ @@ -28,6 +28,7 @@ def test_get_efa(self) -> None: src_path="/dev/infiniband/uverbs1", dst_path="/dev/infiniband/uverbs1", ), + DeviceMount(src_path="/dev/neuron0", dst_path="/dev/neuron0"), ], ) diff --git a/torchx/schedulers/test/docker_scheduler_test.py b/torchx/schedulers/test/docker_scheduler_test.py index 59839b1f4..d5c9b3073 100644 --- a/torchx/schedulers/test/docker_scheduler_test.py +++ b/torchx/schedulers/test/docker_scheduler_test.py @@ -161,12 +161,19 @@ def test_device_mounts(self) -> None: def test_resource_devices(self) -> None: app = _test_app() app.roles[0].mounts = [] - app.roles[0].resource.devices = {"vpc.amazonaws.com/efa": 1} + app.roles[0].resource.devices = { + "vpc.amazonaws.com/efa": 1, + "aws.amazon.com/neurondevice": 2, + } info = self.scheduler.submit_dryrun(app, cfg={}) self.assertEqual( info.request.containers[0].kwargs["devices"], - ["/dev/infiniband/uverbs0:/dev/infiniband/uverbs0:rwm"], + [ + "/dev/infiniband/uverbs0:/dev/infiniband/uverbs0:rwm", + "/dev/neuron0:/dev/neuron0:rwm", + "/dev/neuron1:/dev/neuron1:rwm", + ], ) @patch("os.environ", {"FOO_1": "f1", "BAR_1": "b1", "FOOBAR_1": "fb1"}) diff --git a/torchx/specs/named_resources_aws.py b/torchx/specs/named_resources_aws.py index 1ddb681a8..cbd69988b 100644 --- a/torchx/specs/named_resources_aws.py +++ b/torchx/specs/named_resources_aws.py @@ -37,6 +37,7 @@ from torchx.specs.api import Resource EFA_DEVICE = "vpc.amazonaws.com/efa" +NEURON_DEVICE = "aws.amazon.com/neurondevice" # ecs and ec2 have memtax and currently AWS Batch uses hard memory limits # so we have to account for mem tax when registering these resources for AWS @@ -255,7 +256,11 @@ def aws_g5_48xlarge() -> Resource: def aws_trn1_2xlarge() -> Resource: return Resource( - cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xlarge"} + cpu=8, + gpu=0, + memMB=32 * GiB, + capabilities={K8S_ITYPE: "trn1.2xlarge"}, + devices={NEURON_DEVICE: 1}, ) @@ -265,7 +270,7 @@ def aws_trn1_32xlarge() -> Resource: gpu=0, memMB=512 * GiB, capabilities={K8S_ITYPE: "trn1.32xlarge"}, - devices={EFA_DEVICE: 8}, + devices={EFA_DEVICE: 8, NEURON_DEVICE: 16}, ) diff --git a/torchx/specs/test/named_resources_aws_test.py b/torchx/specs/test/named_resources_aws_test.py index b043d32af..064483f8f 100644 --- a/torchx/specs/test/named_resources_aws_test.py +++ b/torchx/specs/test/named_resources_aws_test.py @@ -38,6 +38,7 @@ GiB, K8S_ITYPE, NAMED_RESOURCES, + NEURON_DEVICE, ) @@ -170,11 +171,13 @@ def test_aws_trn1(self) -> None: self.assertEqual(8, trn1_2.cpu) self.assertEqual(0, trn1_2.gpu) self.assertEqual(32 * GiB, trn1_2.memMB) + self.assertEqual({NEURON_DEVICE: 1}, trn1_2.devices) trn1_32 = aws_trn1_32xlarge() self.assertEqual(trn1_32.cpu, trn1_2.cpu * 16) self.assertEqual(trn1_32.gpu, trn1_2.gpu) self.assertEqual(trn1_32.memMB, trn1_2.memMB * 16) + self.assertEqual({EFA_DEVICE: 8, NEURON_DEVICE: 16}, trn1_32.devices) def test_aws_m5_2xlarge(self) -> None: resource = aws_m5_2xlarge()