Skip to content

Commit 4fd830b

Browse files
committed
feat: add ulimits support to aws_batch (#1126)
1 parent b72ba03 commit 4fd830b

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,37 @@
9999
TAG_TORCHX_USER = "torchx.pytorch.org/user"
100100

101101

102+
def parse_ulimits(ulimits_list: list[str]) -> List[Dict[str, Any]]:
    """
    Parse ulimit specs into a list of ``name``/``softLimit``/``hardLimit`` dicts.

    Each spec has the format ``name:softLimit:hardLimit``. Multiple ulimits
    may be given either as separate list items or separated by commas inside
    a single item (e.g. ``"nofile:65536:65536,memlock:-1:-1"``). Blank specs
    and whitespace around any field are ignored.

    Args:
        ulimits_list: the ulimit specs to parse. A bare string is treated
            as a one-item list.

    Returns:
        One dict per spec, in input order, with the keys ``name`` (str),
        ``softLimit`` (int) and ``hardLimit`` (int). Empty input yields ``[]``.

    Raises:
        ValueError: if a spec does not have exactly three ``:``-separated
            fields, has an empty name, or has a non-integer limit.
    """
    if not ulimits_list:
        return []

    # ``AWSBatchOpts`` types this option as a string; without this guard a
    # bare string would be iterated character by character.
    if isinstance(ulimits_list, str):
        ulimits_list = [ulimits_list]

    ulimits: List[Dict[str, Any]] = []
    for item in ulimits_list:
        # Honor the documented "multiple separated by commas" form.
        for ulimit_str in item.split(","):
            spec = ulimit_str.strip()
            if not spec:
                continue

            parts = [part.strip() for part in spec.split(":")]
            if len(parts) != 3 or not parts[0]:
                raise ValueError(
                    f"ulimit must be in format name:softLimit:hardLimit, got: {ulimit_str}"
                )

            name, soft_limit, hard_limit = parts
            try:
                # int() already maps "-1" to -1, so no special case is needed.
                soft = int(soft_limit)
                hard = int(hard_limit)
            except ValueError as err:
                # Re-raise with the offending spec instead of the bare
                # "invalid literal for int()" message.
                raise ValueError(
                    f"ulimit must be in format name:softLimit:hardLimit, got: {ulimit_str}"
                ) from err

            ulimits.append({"name": name, "softLimit": soft, "hardLimit": hard})

    return ulimits
131+
132+
102133
if TYPE_CHECKING:
103134
from docker import DockerClient
104135

@@ -177,6 +208,7 @@ def _role_to_node_properties(
177208
privileged: bool = False,
178209
job_role_arn: Optional[str] = None,
179210
execution_role_arn: Optional[str] = None,
211+
ulimits: Optional[List[Dict[str, Any]]] = None,
180212
) -> Dict[str, object]:
181213
role.mounts += get_device_mounts(role.resource.devices)
182214

@@ -239,6 +271,7 @@ def _role_to_node_properties(
239271
"environment": [{"name": k, "value": v} for k, v in role.env.items()],
240272
"privileged": privileged,
241273
"resourceRequirements": resource_requirements_from_resource(role.resource),
274+
**({"ulimits": ulimits} if ulimits else {}),
242275
"linuxParameters": {
243276
# To support PyTorch dataloaders we need to set /dev/shm to larger
244277
# than the 64M default.
@@ -361,6 +394,7 @@ class AWSBatchOpts(TypedDict, total=False):
361394
priority: int
362395
job_role_arn: Optional[str]
363396
execution_role_arn: Optional[str]
397+
ulimits: Optional[str]
364398

365399

366400
class AWSBatchScheduler(
@@ -514,6 +548,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
514548
privileged=cfg["privileged"],
515549
job_role_arn=cfg.get("job_role_arn"),
516550
execution_role_arn=cfg.get("execution_role_arn"),
551+
ulimits=parse_ulimits(cfg.get("ulimits") or []),
517552
)
518553
)
519554
node_idx += role.num_replicas
@@ -599,6 +634,11 @@ def _run_opts(self) -> runopts:
599634
type_=str,
600635
help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.",
601636
)
637+
opts.add(
638+
"ulimits",
639+
type_=list[str],
640+
help="Ulimit settings in format: name:softLimit:hardLimit (multiple separated by commas)",
641+
)
602642
return opts
603643

604644
def _get_job_id(self, app_id: str) -> Optional[str]:

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AWSBatchScheduler,
2424
create_scheduler,
2525
ENV_TORCHX_ROLE_NAME,
26+
parse_ulimits,
2627
resource_from_resource_requirements,
2728
resource_requirements_from_resource,
2829
to_millis_since_epoch,
@@ -396,6 +397,46 @@ def test_resource_devices(self) -> None:
396397
],
397398
)
398399

400+
def test_role_to_node_properties_ulimits(self) -> None:
    """Ulimits passed to ``_role_to_node_properties`` appear under ``container.ulimits``."""
    resource = specs.Resource(cpu=1, memMB=1000, gpu=0)
    test_role = specs.Role(
        name="test",
        image="test:latest",
        entrypoint="test",
        args=["test"],
        resource=resource,
    )
    expected_ulimits = [
        {"name": "nofile", "softLimit": 65536, "hardLimit": 65536},
        {"name": "memlock", "softLimit": -1, "hardLimit": -1},
    ]

    node_props = _role_to_node_properties(test_role, 0, ulimits=expected_ulimits)

    container = node_props["container"]
    self.assertEqual(container["ulimits"], expected_ulimits)
417+
418+
def test_parse_ulimits(self) -> None:
    """``parse_ulimits`` handles single, multiple, empty and malformed input."""
    nofile = {"name": "nofile", "softLimit": 65536, "hardLimit": 65536}
    memlock = {"name": "memlock", "softLimit": -1, "hardLimit": -1}

    # A single spec produces a single entry.
    self.assertEqual(parse_ulimits(["nofile:65536:65536"]), [nofile])

    # Several specs keep their input order.
    self.assertEqual(
        parse_ulimits(["nofile:65536:65536", "memlock:-1:-1"]),
        [nofile, memlock],
    )

    # No specs, no ulimits.
    self.assertEqual(parse_ulimits([]), [])

    # A spec without the three colon-separated fields is rejected.
    with self.assertRaises(ValueError):
        parse_ulimits(["invalid"])
439+
399440
def _mock_scheduler_running_job(self) -> AWSBatchScheduler:
400441
scheduler = AWSBatchScheduler(
401442
"test",

0 commit comments

Comments
 (0)