Skip to content

Commit

Permalink
Adding Override Option for TorchX Role (#956)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #956

Adds a generic way to override the internal values of the Role. Allows async overriding of role values and enable Async Packaging

Differential Revision: D62591176
  • Loading branch information
andywag authored and facebook-github-bot committed Sep 17, 2024
1 parent b7fd00b commit 4d81c76
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
18 changes: 18 additions & 0 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

# pyre-strict

import asyncio
import copy
import inspect
import json
import re
import typing
Expand Down Expand Up @@ -370,6 +372,22 @@ class Role:
mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field(
default_factory=list
)
overrides: Dict[str, Any] = field(default_factory=dict)

# pyre-ignore
def __getattribute__(self, attrname: str) -> Any:
try:
ov = super().__getattribute__("overrides")
except AttributeError:
ov = {}
if attrname in ov:
if inspect.isawaitable(ov[attrname]):
result = asyncio.get_event_loop().run_until_complete(ov[attrname])
else:
result = ov[attrname]()
setattr(self, attrname, result)
del ov[attrname]
return super().__getattribute__(attrname)

def pre_proc(
self,
Expand Down
23 changes: 23 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-strict

import asyncio
import os
import time
import unittest
Expand Down Expand Up @@ -276,6 +277,28 @@ def test_retry_policies(self) -> None:
},
)

def test_override_role(self) -> None:
default = Role(
"foobar",
"torch",
overrides={"image": lambda: "base", "entrypoint": lambda: "nentry"},
)
self.assertEqual("base", default.image)
self.assertEqual("nentry", default.entrypoint)

def test_async_override_role(self) -> None:
async def update(value: str, time_seconds: int) -> str:
await asyncio.sleep(time_seconds)
return value

default = Role(
"foobar",
"torch",
overrides={"image": update("base", 1), "entrypoint": update("nentry", 2)},
)
self.assertEqual("base", default.image)
self.assertEqual("nentry", default.entrypoint)


class AppHandleTest(unittest.TestCase):
def test_parse_malformed_app_handles(self) -> None:
Expand Down

0 comments on commit 4d81c76

Please sign in to comment.