From b54df67d978631e813b6c98c52ba2f5d058a241b Mon Sep 17 00:00:00 2001 From: Alexander Jipa Date: Wed, 24 Jan 2024 14:26:27 -0500 Subject: [PATCH] adding module lookup for building trackers (#803) Co-authored-by: Alexander Jipa --- torchx/schedulers/kubernetes_scheduler.py | 2 +- torchx/tracker/__init__.py | 12 ++++++--- torchx/tracker/api.py | 25 ++++++++--------- torchx/tracker/test/api_test.py | 24 ++++++++++++++++- torchx/util/modules.py | 33 +++++++++++++++++++++++ torchx/util/test/modules_test.py | 23 ++++++++++++++++ 6 files changed, 99 insertions(+), 20 deletions(-) create mode 100644 torchx/util/modules.py create mode 100644 torchx/util/test/modules_test.py diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index eafcc55cd..c5a2c22bd 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -23,7 +23,7 @@ kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.6.0/installer/volcano-development.yaml See the -`Volcano Quickstart `_ +`Volcano Quickstart `_ for more information. """ diff --git a/torchx/tracker/__init__.py b/torchx/tracker/__init__.py index a25f96cd8..813055183 100644 --- a/torchx/tracker/__init__.py +++ b/torchx/tracker/__init__.py @@ -37,7 +37,7 @@ ------------- To enable tracking it requires: -1. Defining tracker backends (entrypoints and configuration) on launcher side using :doc:`runner.config` +1. Defining tracker backends (entrypoints/modules and configuration) on launcher side using :doc:`runner.config` 2. Adding entrypoints within a user job using entry_points (`specification`_) .. _specification: https://packaging.python.org/en/latest/specifications/entry-points/ @@ -49,13 +49,13 @@ User can define any number of tracker backends under **torchx:tracker** section in :doc:`runner.config`, where: * Key: is an arbitrary name for the tracker, where the name will be used to configure its properties under [tracker:] - * Value: is *entrypoint/factory method* that must be available within user job. The value will be injected into a + * Value: is *entrypoint* or *module* factory method that must be available within user job. The value will be injected into a user job and used to construct tracker implementation. .. code-block:: ini [torchx:tracker] - tracker_name= + tracker_name= Each tracker can be additionally configured (currently limited to `config` parameter) under `[tracker:]` section: @@ -71,11 +71,15 @@ [torchx:tracker] tracker1=tracker1 - tracker12=backend_2_entry_point + tracker2=backend_2_entry_point + tracker3=torchx.tracker.mlflow:create_tracker [tracker:tracker1] config=s3://my_bucket/config.json + [tracker:tracker3] + config=my_config.json + 2. User job configuration (Advanced) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/torchx/tracker/api.py b/torchx/tracker/api.py index 76c139166..e288ac883 100644 --- a/torchx/tracker/api.py +++ b/torchx/tracker/api.py @@ -14,6 +14,7 @@ from typing import Iterable, Mapping, Optional from torchx.util.entrypoints import load_group +from torchx.util.modules import load_module logger: logging.Logger = logging.getLogger(__name__) @@ -177,30 +178,26 @@ def _extract_tracker_name_and_config_from_environ() -> Mapping[str, Optional[str def build_trackers( - entrypoint_and_config: Mapping[str, Optional[str]] + factory_and_config: Mapping[str, Optional[str]] ) -> Iterable[TrackerBase]: trackers = [] - entrypoint_factories = load_group("torchx.tracker") + entrypoint_factories = load_group("torchx.tracker") or {} if not entrypoint_factories: - logger.warning( - "No 'torchx.tracker' entry_points are defined. Tracking will not capture any data." - ) - return trackers + logger.warning("No 'torchx.tracker' entry_points are defined.") - for entrypoint_key, config in entrypoint_and_config.items(): - if entrypoint_key not in entrypoint_factories: + for factory_name, config in factory_and_config.items(): + factory = entrypoint_factories.get(factory_name) or load_module(factory_name) + if not factory: logger.warning( - f"Could not find `{entrypoint_key}` tracker entrypoint. Skipping..." + f"No tracker factory `{factory_name}` found in entry_points or modules. See https://pytorch.org/torchx/main/tracker.html#module-torchx.tracker" ) continue - factory = entrypoint_factories[entrypoint_key] if config: - logger.info(f"Tracker config found for `{entrypoint_key}` as `{config}`") - tracker = factory(config) + logger.info(f"Tracker config found for `{factory_name}` as `{config}`") else: - logger.info(f"No tracker config specified for `{entrypoint_key}`") - tracker = factory(None) + logger.info(f"No tracker config specified for `{factory_name}`") + tracker = factory(config) trackers.append(tracker) return trackers diff --git a/torchx/tracker/test/api_test.py b/torchx/tracker/test/api_test.py index 65f944087..b17c4c73e 100644 --- a/torchx/tracker/test/api_test.py +++ b/torchx/tracker/test/api_test.py @@ -9,7 +9,7 @@ from collections import defaultdict from typing import cast, DefaultDict, Dict, Iterable, Mapping, Optional, Tuple from unittest import mock, TestCase -from unittest.mock import patch +from unittest.mock import MagicMock, patch from torchx.tracker import app_run_from_env from torchx.tracker.api import ( @@ -27,6 +27,8 @@ TrackerSource, ) +from torchx.tracker.mlflow import MLflowTracker + RunId = str DEFAULT_SOURCE: str = "__parent__" @@ -271,6 +273,26 @@ def test_build_trackers_with_no_entrypoints_group_defined(self) -> None: trackers = build_trackers(tracker_names) self.assertEqual(0, len(list(trackers))) + def test_build_trackers_with_module(self) -> None: + module = MagicMock() + module.return_value = MagicMock(spec=MLflowTracker) + with patch( + "torchx.tracker.api.load_group", + return_value=None, + ) and patch( + "torchx.tracker.api.load_module", + return_value=module, + ): + tracker_names = { + "torchx.tracker.mlflow:create_tracker": (config := "myconfig.txt") + } + trackers = build_trackers(tracker_names) + trackers = list(trackers) + self.assertEqual(1, len(trackers)) + tracker = trackers[0] + self.assertIsInstance(tracker, MLflowTracker) + module.assert_called_once_with(config) + def test_build_trackers(self) -> None: with patch( "torchx.tracker.api.load_group", diff --git a/torchx/util/modules.py b/torchx/util/modules.py new file mode 100644 index 000000000..6447db687 --- /dev/null +++ b/torchx/util/modules.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +from types import ModuleType +from typing import Callable, Optional, Union + + +def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., object]]]: + """ + Loads and returns the module/module attr represented by the ``path``: ``full.module.path:optional_attr`` + + :: + + + 1. ``load_module("this.is.a_module:fn")`` -> equivalent to ``this.is.a_module.fn`` + 1. ``load_module("this.is.a_module")`` -> equivalent to ``this.is.a_module`` + """ + parts = path.split(":", 2) + module_path, method = parts[0], parts[1] if len(parts) > 1 else None + module = None + i, n = -1, len(module_path) + try: + while i < n: + i = module_path.find(".", i + 1) + i = i if i >= 0 else n + module = importlib.import_module(module_path[:i]) + return getattr(module, method) if method else module + except Exception: + return None diff --git a/torchx/util/test/modules_test.py b/torchx/util/test/modules_test.py new file mode 100644 index 000000000..7b490fc70 --- /dev/null +++ b/torchx/util/test/modules_test.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from torchx.util.modules import load_module + + +class ModulesTest(unittest.TestCase): + def test_load_module(self) -> None: + result = load_module("os.path") + import os + + self.assertEqual(result, os.path) + + def test_load_module_method(self) -> None: + result = load_module("os.path:join") + import os + + self.assertEqual(result, os.path.join)