Skip to content

Commit a264c75

Browse files
authored
(torchx/specs) Add TORCHX_HOME function that returns the dot-torchx directory
Differential Revision: D83575373 Pull Request resolved: #1133
1 parent dc79474 commit a264c75

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

torchx/specs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
RoleStatus,
4242
runopt,
4343
runopts,
44+
TORCHX_HOME,
4445
UnknownAppException,
4546
UnknownSchedulerException,
4647
VolumeMount,
@@ -53,6 +54,7 @@
5354

5455
GiB: int = 1024
5556

57+
5658
ResourceFactory = Callable[[], Resource]
5759

5860
AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(

torchx/specs/api.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import inspect
1212
import json
1313
import logging as logger
14+
import os
15+
import pathlib
1416
import re
1517
import typing
1618
from dataclasses import asdict, dataclass, field
@@ -66,6 +68,32 @@
6668
RESET = "\033[0m"
6769

6870

71+
def TORCHX_HOME(*subdir_paths: str) -> pathlib.Path:
72+
"""
73+
Path to the "dot-directory" for torchx.
74+
Defaults to `~/.torchx` and is overridable via the `TORCHX_HOME` environment variable.
75+
76+
Usage:
77+
78+
.. doc-test::
79+
80+
from pathlib import Path
81+
from torchx.specs import TORCHX_HOME
82+
83+
assert TORCHX_HOME() == Path.home() / ".torchx"
84+
assert TORCHX_HOME("conda-pack-out") == Path.home() / ".torchx" / "conda-pack-out"
85+
```
86+
"""
87+
88+
default_dir = str(pathlib.Path.home() / ".torchx")
89+
torchx_home = pathlib.Path(os.getenv("TORCHX_HOME", default_dir))
90+
91+
torchx_home = torchx_home / os.path.sep.join(subdir_paths)
92+
torchx_home.mkdir(parents=True, exist_ok=True)
93+
94+
return torchx_home
95+
96+
6997
# ========================================
7098
# ==== Distributed AppDef API =======
7199
# ========================================

torchx/specs/test/api_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
import asyncio
1111
import concurrent
1212
import os
13+
import tempfile
1314
import time
1415
import unittest
1516
from dataclasses import asdict
17+
from pathlib import Path
1618
from typing import Dict, List, Mapping, Tuple, Union
19+
from unittest import mock
1720
from unittest.mock import MagicMock
1821

1922
import torchx.specs.named_resources_aws as named_resources_aws
@@ -40,9 +43,37 @@
4043
RoleStatus,
4144
runopt,
4245
runopts,
46+
TORCHX_HOME,
4347
)
4448

4549

50+
class TorchXHomeTest(unittest.TestCase):
51+
# guard against TORCHX_HOME set outside the test
52+
@mock.patch.dict(os.environ, {}, clear=True)
53+
def test_TORCHX_HOME_default(self) -> None:
54+
with tempfile.TemporaryDirectory() as tmpdir:
55+
user_home = Path(tmpdir) / "sally"
56+
with mock.patch("pathlib.Path.home", return_value=user_home):
57+
torchx_home = TORCHX_HOME()
58+
self.assertEqual(torchx_home, user_home / ".torchx")
59+
self.assertTrue(torchx_home.exists())
60+
61+
def test_TORCHX_HOME_override(self) -> None:
62+
with tempfile.TemporaryDirectory() as tmpdir:
63+
override_torchx_home = Path(tmpdir) / "test" / ".torchx"
64+
with mock.patch.dict(
65+
os.environ, {"TORCHX_HOME": str(override_torchx_home)}
66+
):
67+
torchx_home = TORCHX_HOME()
68+
conda_pack_out = TORCHX_HOME("conda-pack", "out")
69+
70+
self.assertEqual(override_torchx_home, torchx_home)
71+
self.assertEqual(torchx_home / "conda-pack" / "out", conda_pack_out)
72+
73+
self.assertTrue(torchx_home.is_dir())
74+
self.assertTrue(conda_pack_out.is_dir())
75+
76+
4677
class AppDryRunInfoTest(unittest.TestCase):
4778
def test_repr(self) -> None:
4879
request_mock = MagicMock()

0 commit comments

Comments
 (0)