Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-39902: add support for deprecating connections and connection templates #351

Merged
merged 4 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/DM-39902
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make it possible to deprecate `PipelineTask` connections.
7 changes: 5 additions & 2 deletions python/lsst/pipe/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,18 @@ def __new__(
configConnectionsNamespace: dict[str, pexConfig.Field] = {}
for fieldName, obj in connectionsClass.allConnections.items():
configConnectionsNamespace[fieldName] = pexConfig.Field[str](
doc=f"name for connection {fieldName}", default=obj.name
doc=f"name for connection {fieldName}", default=obj.name, deprecated=obj.deprecated
)
# If there are default templates also add them as fields to
# configure the template values
if hasattr(connectionsClass, "defaultTemplates"):
docString = "Template parameter used to format corresponding field template parameter"
for templateName, default in connectionsClass.defaultTemplates.items():
configConnectionsNamespace[templateName] = TemplateField(
dtype=str, doc=docString, default=default
dtype=str,
doc=docString,
default=default,
deprecated=connectionsClass.deprecatedTemplates.get(templateName),
)
# add a reference to the connection class used to create this sub
# config
Expand Down
7 changes: 7 additions & 0 deletions python/lsst/pipe/base/connectionTypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,19 @@ class BaseConnection:
consistent (i.e. zip-iterable) in `PipelineTask.runQuantum()` and
notify the execution system as early as possible of outputs that will
not be produced because the corresponding input is missing.
deprecated : `str`, optional
A description of why this connection is deprecated, including the
version after which it may be removed.

If not `None`, the string is appended to the docstring for this
connection and the corresponding config Field.
"""

name: str
storageClass: str
doc: str = ""
multiple: bool = False
deprecated: str | None = dataclasses.field(default=None, kw_only=True)

_connection_type_set: ClassVar[str]

Expand Down
27 changes: 25 additions & 2 deletions python/lsst/pipe/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@
import dataclasses
import itertools
import string
import warnings
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, NamedKeyDict, NamedKeyMapping, Quantum
from lsst.utils.introspection import find_outside_stacklevel

from ._status import NoWorkFound
from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInput
Expand Down Expand Up @@ -222,14 +224,20 @@ def __new__(cls, name, bases, dct, **kwargs):
# look up any template from base classes and merge them all
# together
mergeDict = {}
mergeDeprecationsDict = {}
for base in bases[::-1]:
if hasattr(base, "defaultTemplates"):
mergeDict.update(base.defaultTemplates)
if hasattr(base, "deprecatedTemplates"):
mergeDeprecationsDict.update(base.deprecatedTemplates)
if "defaultTemplates" in kwargs:
mergeDict.update(kwargs["defaultTemplates"])

if "deprecatedTemplates" in kwargs:
mergeDeprecationsDict.update(kwargs["deprecatedTemplates"])
if len(mergeDict) > 0:
kwargs["defaultTemplates"] = mergeDict
if len(mergeDeprecationsDict) > 0:
kwargs["deprecatedTemplates"] = mergeDeprecationsDict

# Verify that if templated strings were used, defaults were
# supplied as an argument in the declaration of the connection
Expand All @@ -256,6 +264,7 @@ def __new__(cls, name, bases, dct, **kwargs):
f" (conflicts are {nameTemplateIntersection})."
)
dct["defaultTemplates"] = kwargs.get("defaultTemplates", {})
dct["deprecatedTemplates"] = kwargs.get("deprecatedTemplates", {})

# Convert all the connection containers into frozensets so they cannot
# be modified at the class scope
Expand Down Expand Up @@ -317,7 +326,15 @@ def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskCo
instance.allConnections = MappingProxyType(instance._allConnections)
for internal_name, connection in cls.allConnections.items():
dataset_type_name = getattr(config.connections, internal_name).format(**templateValues)
instance_connection = dataclasses.replace(connection, name=dataset_type_name)
instance_connection = dataclasses.replace(
connection,
name=dataset_type_name,
doc=(
connection.doc
if connection.deprecated is None
else f"{connection.doc}\n{connection.deprecated}"
),
)
instance._allConnections[internal_name] = instance_connection

# Finally call __init__. The base class implementation does nothing;
Expand Down Expand Up @@ -352,6 +369,12 @@ def __call__(cls, *, config: PipelineTaskConfig | None = None) -> PipelineTaskCo
instance._allConnections.clear()
instance._allConnections.update(updated_all_connections)

for obj in instance._allConnections.values():
if obj.deprecated is not None:
warnings.warn(
obj.deprecated, FutureWarning, stacklevel=find_outside_stacklevel("lsst.pipe.base")
)

# Freeze the connection instance dimensions now. This at odds with the
# type annotation, which says [mutable] `set`, just like the connection
# type attributes (e.g. `inputs`, `outputs`, etc.), though MyPy can't
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/graph/_implDetails.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ def _pruner(
taskClass=node.quantum.taskClass,
dataId=node.quantum.dataId,
initInputs=node.quantum.initInputs,
inputs=helper.inputs, # type: ignore
outputs=helper.outputs, # type: ignore
inputs=helper.inputs,
outputs=helper.outputs,
)
# If the inputs or outputs were adjusted to something different
# than what was supplied by the graph builder, dissassociate
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ def makeQuantum(self, datastore_records: Mapping[str, DatastoreRecordData] | Non
taskClass=self.task.taskDef.taskClass,
dataId=self.dataId,
initInputs=initInputs,
inputs=helper.inputs, # type: ignore
outputs=helper.outputs, # type: ignore
inputs=helper.inputs,
outputs=helper.outputs,
datastore_records=quantum_records,
)

Expand Down
40 changes: 40 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
"""

import unittest
import warnings

import lsst.pipe.base as pipeBase
import lsst.utils.tests
import pytest
from lsst.pex.config import Field


class TestConnectionsClass(unittest.TestCase):
Expand Down Expand Up @@ -182,6 +184,44 @@ class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTask, dimensions=
with self.assertRaises(TypeError):
pipeBase.connectionTypes.Output(Doc="mock doc", dimensions=1, name="output", storageClass="mock")

def test_deprecation(self) -> None:
"""Test support for deprecating connections."""

class TestConnections(
pipeBase.PipelineTaskConnections,
dimensions=self.test_dims,
defaultTemplates={"t1": "dataset_type_1"},
deprecatedTemplates={"t1": "Deprecated in v600, will be removed after v601."},
):
input1 = pipeBase.connectionTypes.Input(
doc="Docs for input1",
name="input1_{t1}",
storageClass="StructuredDataDict",
deprecated="Deprecated in v50000, will be removed after v50001.",
)

def __init__(self, config):
if config.drop_input1:
del self.input1

class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnections):
drop_input1 = Field("Remove the 'input1' connection if True", dtype=bool, default=False)

config = TestConfig()
with self.assertWarns(FutureWarning):
config.connections.input1 = "dataset_type_2"
with self.assertWarns(FutureWarning):
config.connections.t1 = "dataset_type_3"

with self.assertWarns(FutureWarning):
TestConnections(config=config)

config.drop_input1 = True

with warnings.catch_warnings():
warnings.simplefilter("error", FutureWarning)
TestConnections(config=config)


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
pass
Expand Down