diff --git a/doc/changes/DM-40198.feature.md b/doc/changes/DM-40198.feature.md new file mode 100644 index 000000000..8433de7d9 --- /dev/null +++ b/doc/changes/DM-40198.feature.md @@ -0,0 +1 @@ +Parameters defined in a Pipeline can now be used within a config python block as well as within config files loaded by a Pipeline. diff --git a/python/lsst/pipe/base/config.py b/python/lsst/pipe/base/config.py index bc984ee0a..ca4ce1423 100644 --- a/python/lsst/pipe/base/config.py +++ b/python/lsst/pipe/base/config.py @@ -241,6 +241,7 @@ def applyConfigOverrides( The label associated with this class's Task in a pipeline. """ overrides = ConfigOverrides() + overrides.addParameters(parameters) if instrument is not None: overrides.addInstrumentOverride(instrument, taskDefaultName) if pipelineConfigs is not None: diff --git a/python/lsst/pipe/base/configOverrides.py b/python/lsst/pipe/base/configOverrides.py index 3c49ed2b0..4fe2d16d9 100644 --- a/python/lsst/pipe/base/configOverrides.py +++ b/python/lsst/pipe/base/configOverrides.py @@ -21,6 +21,7 @@ """Module which defines ConfigOverrides class and related methods. """ +from __future__ import annotations __all__ = ["ConfigOverrides"] @@ -28,14 +29,34 @@ import inspect from enum import Enum from operator import attrgetter +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any from lsst.resources import ResourcePath from ._instrument import Instrument +if TYPE_CHECKING: + from .pipelineIR import ParametersIR + OverrideTypes = Enum("OverrideTypes", "Value File Python Instrument") +class _FrozenSimpleNamespace(SimpleNamespace): + """SimpleNamespace subclass which disallows setting after construction""" + + def __init__(self, **kwargs: Any) -> None: + object.__setattr__(self, "_frozen", False) + super().__init__(**kwargs) + self._frozen = True + + def __setattr__(self, __name: str, __value: Any) -> None: + if self._frozen: + raise ValueError("Cannot set attributes on parameters") + else: + return super().__setattr__(__name, __value) + + class ConfigExpressionParser(ast.NodeVisitor): """An expression parser that will be used to transform configuration strings supplied from the command line or a pipeline into a python @@ -139,8 +160,28 @@ class ConfigOverrides: necessary. """ - def __init__(self): - self._overrides = [] + def __init__(self) -> None: + self._overrides: list[tuple[OverrideTypes, Any]] = [] + self._parameters: SimpleNamespace | None = None + + def addParameters(self, parameters: ParametersIR) -> None: + """Add parameters which will be substituted when applying overrides. + + Parameters + ---------- + parameters : `ParametersIR` + Override parameters in the form as read from a Pipeline file. + + Note + ---- + This method may be called more than once, but each call will overwrite + any previous parameter defined with the same name. + """ + if self._parameters is None: + self._parameters = SimpleNamespace() + + for key, value in parameters.mapping.items(): + setattr(self._parameters, key, value) def addFileOverride(self, filename): """Add overrides from a specified file. @@ -222,20 +263,28 @@ def applyTo(self, config): # Look up a stack of variables people may be using when setting # configs. Create a dictionary that will be used akin to a namespace # for the duration of this function. - vars = {} + localVars = {} # pull in the variables that are declared in module scope of the config mod = inspect.getmodule(config) - vars.update({k: v for k, v in mod.__dict__.items() if not k.startswith("__")}) + localVars.update({k: v for k, v in mod.__dict__.items() if not k.startswith("__")}) # put the supplied config in the variables dictionary - vars["config"] = config + localVars["config"] = config + extraLocals = None + + # If any parameters are supplied add them to the variables dictionary + if self._parameters is not None: + # make a copy of the params and "freeze" it + localParams = _FrozenSimpleNamespace(**vars(self._parameters)) + localVars["parameters"] = localParams + extraLocals = {"parameters": localParams} # Create a parser for config expressions that may be strings - configParser = ConfigExpressionParser(namespace=vars) + configParser = ConfigExpressionParser(namespace=localVars) for otype, override in self._overrides: if otype is OverrideTypes.File: with override.open("r") as buffer: - config.loadFromStream(buffer, filename=override.ospath) + config.loadFromStream(buffer, filename=override.ospath, extraLocals=extraLocals) elif otype is OverrideTypes.Value: field, value = override if isinstance(value, str): @@ -289,7 +338,7 @@ def applyTo(self, config): # in a python block will be put into this scope. This means # other config setting branches can make use of these # variables. - exec(override, None, vars) + exec(override, None, localVars) elif otype is OverrideTypes.Instrument: instrument, name = override instrument.applyConfigOverrides(name, config) diff --git a/python/lsst/pipe/base/pipelineIR.py b/python/lsst/pipe/base/pipelineIR.py index 6bf9c080c..f8a57dd81 100644 --- a/python/lsst/pipe/base/pipelineIR.py +++ b/python/lsst/pipe/base/pipelineIR.py @@ -227,7 +227,7 @@ class ParametersIR: field2: parameters.shared_value """ - mapping: MutableMapping[str, str] + mapping: MutableMapping[str, Any] """A mutable mapping of identifiers as keys, and shared configuration as values. """ diff --git a/tests/test_configOverrides.py b/tests/test_configOverrides.py index e4c0054e5..b9174cecb 100644 --- a/tests/test_configOverrides.py +++ b/tests/test_configOverrides.py @@ -22,11 +22,13 @@ """Simple unit test for configOverrides. """ +import tempfile import unittest import lsst.pex.config as pexConfig import lsst.utils.tests from lsst.pipe.base.configOverrides import ConfigOverrides +from lsst.pipe.base.pipelineIR import ParametersIR # This is used in testSettingVar unit test TEST_CHOICE_VALUE = 1 @@ -266,6 +268,27 @@ def testDictValueInt(self): with self.assertRaises(pexConfig.FieldValidationError): self.checkSingleFieldOverride(field, {"a": "b"}) + def testConfigParameters(self): + """Test that passing parameters works""" + config = ConfigTest() + parameters = ParametersIR(mapping={"number": 6, "text": "hello world"}) + overrides = ConfigOverrides() + overrides.addParameters(parameters) + overrides.addPythonOverride("config.fStr = parameters.text") + with tempfile.NamedTemporaryFile(mode="w") as fileOverride: + fileOverride.write("config.fInt = parameters.number") + fileOverride.seek(0) + overrides.addFileOverride(fileOverride.name) + overrides.applyTo(config) + self.assertEqual(config.fStr, parameters.mapping["text"]) + self.assertEqual(config.fInt, parameters.mapping["number"]) + + overrides = ConfigOverrides() + overrides.addParameters(parameters) + overrides.addPythonOverride("parameters.fail = 9") + with self.assertRaises(ValueError): + overrides.applyTo(config) + class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase): """Check for memory leaks."""