Skip to content

Commit

Permalink
feat: module lookup for building trackers
Browse files Browse the repository at this point in the history
  • Loading branch information
azzhipa committed Dec 19, 2023
1 parent fad46b6 commit bc135cf
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 16 deletions.
27 changes: 14 additions & 13 deletions torchx/tracker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from functools import lru_cache
from typing import Iterable, Mapping, Optional

from torchx.util.entrypoints import load_group
from torchx.util.entrypoints import load_group, load_module

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -181,26 +181,27 @@ def build_trackers(
) -> Iterable[TrackerBase]:
trackers = []

entrypoint_factories = load_group("torchx.tracker")
entrypoint_factories = load_group("torchx.tracker") or {}
if not entrypoint_factories:
logger.warning(
"No 'torchx.tracker' entry_points are defined. Tracking will not capture any data."
)
return trackers
logger.warning("No 'torchx.tracker' entry_points are defined.")

for entrypoint_key, config in entrypoint_and_config.items():
if entrypoint_key not in entrypoint_factories:
logger.warning(
f"Could not find `{entrypoint_key}` tracker entrypoint. Skipping..."
factory = entrypoint_factories.get(entrypoint_key)
if factory is None:
logger.info(
f"Could not find `{entrypoint_key}` among entry_points, checking modules..."
)
continue
factory = entrypoint_factories[entrypoint_key]
factory = load_module(entrypoint_key)
if not factory:
logger.warning(
f"Could not find `{entrypoint_key}` tracker entrypoint. Skipping..."
)
continue
if config:
logger.info(f"Tracker config found for `{entrypoint_key}` as `{config}`")
tracker = factory(config)
else:
logger.info(f"No tracker config specified for `{entrypoint_key}`")
tracker = factory(None)
tracker = factory(config)
trackers.append(tracker)
return trackers

Expand Down
22 changes: 21 additions & 1 deletion torchx/tracker/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import defaultdict
from typing import cast, DefaultDict, Dict, Iterable, Mapping, Optional, Tuple
from unittest import mock, TestCase
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from torchx.tracker import app_run_from_env
from torchx.tracker.api import (
Expand All @@ -27,6 +27,8 @@
TrackerSource,
)

from torchx.tracker.mlflow import MLflowTracker

RunId = str

DEFAULT_SOURCE: str = "__parent__"
Expand Down Expand Up @@ -271,6 +273,24 @@ def test_build_trackers_with_no_entrypoints_group_defined(self) -> None:
trackers = build_trackers(tracker_names)
self.assertEqual(0, len(list(trackers)))

def test_build_trackers_with_module(self) -> None:
module = MagicMock()
module.return_value = MagicMock(spec=MLflowTracker)
with patch(
"torchx.tracker.api.load_group",
return_value=None,
) and patch(
"torchx.tracker.api.load_module",
return_value=module,
):
tracker_names = {
"torchx.tracker.mlflow:create_tracker": (config := "myconfig.txt")
}
trackers = build_trackers(tracker_names)
self.assertEqual(1, len(list(trackers)))
self.assertIsInstance(trackers[0], MLflowTracker)
module.assert_called_once_with(config)

def test_build_trackers(self) -> None:
with patch(
"torchx.tracker.api.load_group",
Expand Down
18 changes: 17 additions & 1 deletion torchx/util/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional
import importlib
from types import ModuleType
from typing import Any, Callable, Dict, Optional, Union

import importlib_metadata as metadata
from importlib_metadata import EntryPoint
Expand Down Expand Up @@ -94,3 +96,17 @@ def load_group(
for ep in entrypoints:
eps[ep.name] = _defer_load_ep(ep)
return eps


def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., Any]]]:
parts = path.split(":", 2)
module_path, method = parts[0], parts[1] if len(parts) > 1 else None
i, n = -1, len(module_path)
try:
while i < n:
i = module_path.find(".", i + 1)
i = i if i >= 0 else n
module = importlib.import_module(module_path[:i])
return getattr(module, method) if method else module
except Exception:
return None
14 changes: 13 additions & 1 deletion torchx/util/test/entrypoints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from importlib_metadata import EntryPoint, EntryPoints

from torchx.util.entrypoints import load, load_group
from torchx.util.entrypoints import load, load_group, load_module


def EntryPoint_from_config(config: ConfigParser) -> List[EntryPoint]:
Expand Down Expand Up @@ -127,3 +127,15 @@ def test_load_group_missing(self, _: MagicMock) -> None:

with self.assertRaises(ModuleNotFoundError):
load_group("ep.grp.missing.mod.test")["baz"]()

def test_load_module(self) -> None:
result = load_module("os.path")
import os

self.assertEqual(result, os.path)

def test_load_module_method(self) -> None:
result = load_module("os.path:join")
import os

self.assertEqual(result, os.path.join)

0 comments on commit bc135cf

Please sign in to comment.