Skip to content

Commit

Permalink
adding module lookup for building trackers
Browse files Browse the repository at this point in the history
  • Loading branch information
azzhipa committed Dec 19, 2023
1 parent fad46b6 commit 9c90c3e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
28 changes: 16 additions & 12 deletions torchx/tracker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from functools import lru_cache
from typing import Iterable, Mapping, Optional

import hydra

from torchx.util.entrypoints import load_group

logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -181,26 +183,28 @@ def build_trackers(
) -> Iterable[TrackerBase]:
trackers = []

entrypoint_factories = load_group("torchx.tracker")
entrypoint_factories = load_group("torchx.tracker") or {}
if not entrypoint_factories:
logger.warning(
"No 'torchx.tracker' entry_points are defined. Tracking will not capture any data."
)
return trackers
logger.warning("No 'torchx.tracker' entry_points are defined.")

for entrypoint_key, config in entrypoint_and_config.items():
if entrypoint_key not in entrypoint_factories:
logger.warning(
f"Could not find `{entrypoint_key}` tracker entrypoint. Skipping..."
factory = entrypoint_factories.get(entrypoint_key)
if factory is None:
logger.info(
f"Could not find `{entrypoint_key}` among entry_points, checking modules..."
)
continue
factory = entrypoint_factories[entrypoint_key]
try:
factory = hydra.utils.get_method(entrypoint_key)
except Exception:
logger.warning(
f"Could not find `{entrypoint_key}` tracker entrypoint. Skipping..."
)
continue
if config:
logger.info(f"Tracker config found for `{entrypoint_key}` as `{config}`")
tracker = factory(config)
else:
logger.info(f"No tracker config specified for `{entrypoint_key}`")
tracker = factory(None)
tracker = factory(config)
trackers.append(tracker)
return trackers

Expand Down
12 changes: 12 additions & 0 deletions torchx/tracker/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
TrackerSource,
)

from torchx.tracker.mlflow import MLflowTracker

RunId = str

DEFAULT_SOURCE: str = "__parent__"
Expand Down Expand Up @@ -271,6 +273,16 @@ def test_build_trackers_with_no_entrypoints_group_defined(self) -> None:
trackers = build_trackers(tracker_names)
self.assertEqual(0, len(list(trackers)))

def test_build_trackers_with_module(self) -> None:
with patch(
"torchx.tracker.api.load_group",
return_value=None,
):
tracker_names = {"torchx.tracker.mlflow.create_tracker": "myconfig.txt"}
trackers = build_trackers(tracker_names)
self.assertEqual(1, len(list(trackers)))
self.assertIsInstance(trackers[0], MLflowTracker)

def test_build_trackers(self) -> None:
with patch(
"torchx.tracker.api.load_group",
Expand Down

0 comments on commit 9c90c3e

Please sign in to comment.