Merge pull request #339 from lsst/tickets/DM-38952
DM-38952: revive pipeline mock system and move it here (from ctrl_mpexec)
TallJimbo authored Jun 9, 2023
2 parents 9c4b09d + 7dcc699 commit b6be565
Showing 14 changed files with 1,112 additions and 14 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-38952.feature.md
@@ -0,0 +1 @@
Revived bitrotted support for "mocked" `PipelineTask` execution and moved it here (from `ctrl_mpexec`).
3 changes: 3 additions & 0 deletions doc/lsst.pipe.base/index.rst
@@ -58,6 +58,7 @@ Developing Pipelines
:maxdepth: 1

creating-a-pipeline.rst
testing-pipelines-with-mocks.rst

.. _lsst.pipe.base-contributing:

@@ -85,3 +86,5 @@ Python API reference

.. automodapi:: lsst.pipe.base.pipelineIR
:no-main-docstr:

.. automodapi:: lsst.pipe.base.tests.mocks
25 changes: 25 additions & 0 deletions doc/lsst.pipe.base/testing-pipelines-with-mocks.rst
@@ -0,0 +1,25 @@
.. py:currentmodule:: lsst.pipe.base.tests.mocks
.. _testing-pipelines-with-mocks:

############################
Testing pipelines with mocks
############################

The `lsst.pipe.base.tests.mocks` package provides a way to build and execute `.QuantumGraph` objects without actually running any real task code or relying on real data.
This is primarily for testing the middleware responsible for `.QuantumGraph` generation and execution, but it can also be used to check that the connections in a configured pipeline are consistent with each other and with any documented recommendations for how to run those steps (e.g., which dimensions can safely be constrained by user expressions).

The high-level entry point to this system is the `mock_task_defs` function, which takes an iterable of `.TaskDef` objects (typically obtained from `.Pipeline.toExpandedPipeline`) and returns a new sequence of `.TaskDef` objects, in which each original task has been replaced by a configuration of `MockPipelineTask` whose connections are analogous to those of the original task.
Passing the ``--mock`` option to ``pipetask qgraph`` or ``pipetask run`` will run this on the given pipeline when building the graph.
When a pipeline is mocked, all task labels and dataset types are transformed by the `get_mock_name` function (so these can live alongside their real counterparts in a single data repository), and the storage classes of all regular connections are replaced with instances of `MockStorageClass`.
The in-memory Python type for `MockStorageClass` is always `MockDataset`, which is always written to disk in JSON format.
Conversions between mock storage classes are nevertheless defined analogously to the conversions between the original storage classes they mock, and the `MockDataset` class records conversions (as well as component access and parameters) when they occur, allowing test code that runs later to load the datasets and inspect exactly how each object was loaded and provided to the task when it was executed.
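
A minimal sketch of mocking a pipeline before building a graph yourself (the pipeline file name is hypothetical):

.. code-block:: python

    from lsst.pipe.base import Pipeline
    from lsst.pipe.base.tests.mocks import mock_task_defs

    # Hypothetical pipeline file; any expanded pipeline works the same way.
    pipeline = Pipeline.from_uri("my_pipeline.yaml")
    mocked_task_defs = mock_task_defs(pipeline.toExpandedPipeline())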

The `MockPipelineTask.runQuantum` method reads all input mocked datasets that correspond to a `MockStorageClass`, and simulates reading any input datasets that were not mocked (as designated by the `MockPipelineTaskConfig.unmocked_dataset_types` config option or the `mock_task_defs` argument of the same name) by constructing a new `MockDataset` instance for them.
It then constructs and writes new `MockDataset` instances for each of its predicted outputs, storing copies of the input `MockDataset`\s within them.
`MockPipelineTaskConfig` and `mock_task_defs` also have options for causing quanta that match a data ID expression to raise an exception instead.
Dataset types produced by the execution framework (configs, logs, metadata, and package version information) are not mocked, but they are given names with the prefix added by `get_mock_name`, by virtue of being constructed from a task label that has that prefix.
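
Continuing the sketch above, a hedged example of leaving some input dataset types unmocked; the dataset type names are purely illustrative:

.. code-block:: python

    mocked_task_defs = mock_task_defs(
        pipeline.toExpandedPipeline(),
        unmocked_dataset_types=["camera", "skyMap"],
    )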

Importing the `lsst.pipe.base.tests.mocks` package causes the `~lsst.daf.butler.StorageClassFactory` and `~lsst.daf.butler.FormatterFactory` classes to be monkey-patched with special code that recognizes mock storage class names without being included in any butler configuration files.
This should not affect how any non-mock storage classes are handled, but it is still best to only import `lsst.pipe.base.tests.mocks` in code that is *definitely* using the mock system, even if that means putting the import at function scope instead of module scope.
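
For example, a test helper might defer the import to function scope (a sketch; the helper name is hypothetical):

.. code-block:: python

    def make_mocked_task_defs(pipeline):
        # Import here so the monkey-patching only happens when the mock
        # system is actually used.
        from lsst.pipe.base.tests.mocks import mock_task_defs

        return list(mock_task_defs(pipeline.toExpandedPipeline()))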

The `ci_middleware <https://github.com/lsst/ci_middleware.git>`_ package is the primary place where this mocking library is used, and the home of its unit tests, but it has been designed to be usable in regular "real" data repositories as well.
4 changes: 4 additions & 0 deletions mypy.ini
@@ -7,6 +7,10 @@ plugins = pydantic.mypy
[mypy-networkx.*]
ignore_missing_imports = True

# astropy doesn't ship type annotations
[mypy-astropy.*]
ignore_missing_imports = True

# Don't check LSST packages generally or even try to import them, since most
# don't have type annotations.

11 changes: 11 additions & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"lsst-utils",
"lsst-daf-butler",
"lsst-pex-config",
"astropy",
"pydantic",
"networkx",
"pyyaml >= 5.1",
@@ -114,3 +115,13 @@ addopts = "--flake8"
flake8-ignore = ["E203", "W503", "N802", "N803", "N806", "N812", "N815", "N816"]
# Some unit tests open registry database and don't close it.
open_files_ignore = ["gen3.sqlite3"]

[tool.coverage.report]
show_missing = true
exclude_lines = [
"pragma: no cover",
"raise AssertionError",
"raise NotImplementedError",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]
5 changes: 4 additions & 1 deletion python/lsst/pipe/base/_dataset_handle.py
@@ -23,7 +23,7 @@
__all__ = ["InMemoryDatasetHandle"]

import dataclasses
from typing import Any, Optional
from typing import Any, Optional, cast

from frozendict import frozendict
from lsst.daf.butler import (
@@ -175,6 +175,9 @@ class can be found.
# Parameters for derived components are applied against the
# composite.
if component in thisStorageClass.derivedComponents:
# For some reason MyPy doesn't see the line above as narrowing
# 'component' from 'str | None' to 'str'.
component = cast(str, component)
thisStorageClass.validateParameters(parameters)

# Process the parameters (hoping this never modified the
18 changes: 6 additions & 12 deletions python/lsst/pipe/base/graphBuilder.py
@@ -53,7 +53,6 @@
from lsst.daf.butler.registry import MissingCollectionError, MissingDatasetTypeError
from lsst.daf.butler.registry.queries import DataCoordinateQueryResults
from lsst.daf.butler.registry.wildcards import CollectionWildcard
from lsst.utils import doImportType

# -----------------------------
# Imports for other modules --
@@ -1557,6 +1556,7 @@ def makeGraph(
datasetQueryConstraint: DatasetQueryConstraintVariant = DatasetQueryConstraintVariant.ALL,
metadata: Optional[Mapping[str, Any]] = None,
bind: Optional[Mapping[str, Any]] = None,
dataId: DataCoordinate | None = None,
) -> QuantumGraph:
"""Create execution graph for a pipeline.
@@ -1585,6 +1585,8 @@
bind : `Mapping`, optional
Mapping containing literal values that should be injected into the
``userQuery`` expression, keyed by the identifiers they replace.
dataId : `lsst.daf.butler.DataCoordinate`, optional
Data ID that should also be included in the query constraint.
Returns
-------
@@ -1603,18 +1605,10 @@
scaffolding = _PipelineScaffolding(pipeline, registry=self.registry)
if not collections and (scaffolding.initInputs or scaffolding.inputs or scaffolding.prerequisites):
raise ValueError("Pipeline requires input datasets but no input collections provided.")
instrument_class: Optional[Any] = None
if isinstance(pipeline, Pipeline):
instrument_class_name = pipeline.getInstrument()
if instrument_class_name is not None:
instrument_class = doImportType(instrument_class_name)
pipeline = list(pipeline.toExpandedPipeline())
if instrument_class is not None:
dataId = DataCoordinate.standardize(
instrument=instrument_class.getName(), universe=self.registry.dimensions
)
else:
if dataId is None:
dataId = DataCoordinate.makeEmpty(self.registry.dimensions)
if isinstance(pipeline, Pipeline):
dataId = pipeline.get_data_id(self.registry.dimensions).union(dataId)
with scaffolding.connectDataIds(
self.registry, collections, userQuery, dataId, datasetQueryConstraint, bind
) as commonDataIds:
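
A hedged sketch of passing the new dataId argument to GraphBuilder.makeGraph; only the dataId keyword is taken from the hunk above, and the repository path, pipeline file, collections, run, and query values are assumed placeholders:

from lsst.daf.butler import Butler, DataCoordinate
from lsst.pipe.base import GraphBuilder, Pipeline

butler = Butler("/path/to/repo")  # hypothetical repository
builder = GraphBuilder(butler.registry)
pipeline = Pipeline.from_uri("my_pipeline.yaml")  # hypothetical pipeline file
data_id = DataCoordinate.standardize(instrument="HSC", universe=butler.registry.dimensions)
qgraph = builder.makeGraph(
    pipeline,
    collections=["HSC/defaults"],
    run="u/someone/mock-run",
    userQuery="visit = 123",
    dataId=data_id,
)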
31 changes: 30 additions & 1 deletion python/lsst/pipe/base/pipeline.py
@@ -55,7 +55,14 @@

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.daf.butler import (
DataCoordinate,
DatasetType,
DimensionUniverse,
NamedValueSet,
Registry,
SkyPixDimension,
)
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name
@@ -613,6 +620,28 @@ def getInstrument(self) -> Optional[str]:
"""
return self._pipelineIR.instrument

def get_data_id(self, universe: DimensionUniverse) -> DataCoordinate:
"""Return a data ID with all dimension constraints embedded in the
pipeline.
Parameters
----------
universe : `lsst.daf.butler.DimensionUniverse`
Object that defines all dimensions.
Returns
-------
data_id : `lsst.daf.butler.DataCoordinate`
Data ID with all dimension constraints embedded in the
pipeline.
"""
instrument_class_name = self._pipelineIR.instrument
if instrument_class_name is not None:
instrument_class = doImportType(instrument_class_name)
if instrument_class is not None:
return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe)
return DataCoordinate.makeEmpty(universe)

def addTask(self, task: Union[Type[PipelineTask], str], label: str) -> None:
"""Add a new task to the pipeline, or replace a task that is already
associated with the supplied label.
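
A hedged usage sketch of the new Pipeline.get_data_id method, assuming a butler registry supplies the dimension universe; the repository path and pipeline file are hypothetical:

from lsst.daf.butler import Butler
from lsst.pipe.base import Pipeline

butler = Butler("/path/to/repo")  # hypothetical repository
pipeline = Pipeline.from_uri("my_pipeline.yaml")  # hypothetical pipeline file
# Returns an instrument-constrained data ID if the pipeline declares an
# instrument, and an empty data ID otherwise.
data_id = pipeline.get_data_id(butler.registry.dimensions)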
30 changes: 30 additions & 0 deletions python/lsst/pipe/base/tests/mocks/__init__.py
@@ -0,0 +1,30 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""A system for replacing the tasks in a pipeline with mocks that just read and
write trivial datasets.
See :ref:`testing-pipelines-with-mocks` for details.
"""

from ._data_id_match import *
from ._pipeline_task import *
from ._storage_class import *
173 changes: 173 additions & 0 deletions python/lsst/pipe/base/tests/mocks/_data_id_match.py
@@ -0,0 +1,173 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["DataIdMatch"]

import operator
from collections.abc import Callable
from typing import Any

import astropy.time
from lsst.daf.butler import DataId
from lsst.daf.butler.registry.queries.expressions.parser import Node, ParserYacc, TreeVisitor # type: ignore


class _DataIdMatchTreeVisitor(TreeVisitor):
"""Expression tree visitor which evaluates expression using values from
DataId.
"""

def __init__(self, dataId: DataId):
self.dataId = dataId

def visitNumericLiteral(self, value: str, node: Node) -> Any:
# docstring is inherited from base class
try:
return int(value)
except ValueError:
return float(value)

def visitStringLiteral(self, value: str, node: Node) -> Any:
# docstring is inherited from base class
return value

def visitTimeLiteral(self, value: astropy.time.Time, node: Node) -> Any:
# docstring is inherited from base class
return value

def visitRangeLiteral(self, start: int, stop: int, stride: int | None, node: Node) -> Any:
# docstring is inherited from base class
if stride is None:
return range(start, stop + 1)
else:
return range(start, stop + 1, stride)

def visitIdentifier(self, name: str, node: Node) -> Any:
# docstring is inherited from base class
return self.dataId[name]

def visitUnaryOp(self, operator_name: str, operand: Any, node: Node) -> Any:
# docstring is inherited from base class
operators: dict[str, Callable[[Any], Any]] = {
"NOT": operator.not_,
"+": operator.pos,
"-": operator.neg,
}
return operators[operator_name](operand)

def visitBinaryOp(self, operator_name: str, lhs: Any, rhs: Any, node: Node) -> Any:
# docstring is inherited from base class
operators = {
"OR": operator.or_,
"AND": operator.and_,
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.truediv,
"%": operator.mod,
"=": operator.eq,
"!=": operator.ne,
"<": operator.lt,
">": operator.gt,
"<=": operator.le,
">=": operator.ge,
}
return operators[operator_name](lhs, rhs)

def visitIsIn(self, lhs: Any, values: list[Any], not_in: bool, node: Node) -> Any:
# docstring is inherited from base class
is_in = True
for value in values:
if not isinstance(value, range):
value = [value]
if lhs in value:
break
else:
is_in = False
if not_in:
is_in = not is_in
return is_in

def visitParens(self, expression: Any, node: Node) -> Any:
# docstring is inherited from base class
return expression

def visitTupleNode(self, items: tuple[Any, ...], node: Node) -> Any:
# docstring is inherited from base class
raise NotImplementedError()

def visitFunctionCall(self, name: str, args: list[Any], node: Node) -> Any:
# docstring is inherited from base class
raise NotImplementedError()

def visitPointNode(self, ra: Any, dec: Any, node: Node) -> Any:
# docstring is inherited from base class
raise NotImplementedError()


class DataIdMatch:
"""Class that can match DataId against the user-defined string expression.
Parameters
----------
expression : `str`
User-defined expression, supports syntax defined by daf_butler
expression parser. Maps identifiers in the expression to the values of
DataId.
"""

def __init__(self, expression: str):
parser = ParserYacc()
self.expression = expression
self.tree = parser.parse(expression)

def match(self, dataId: DataId) -> bool:
"""Matches DataId contents against the expression.
Parameters
----------
dataId : `DataId`
DataId that is matched against an expression.
Returns
-------
match : `bool`
Result of expression evaluation.
Raises
------
KeyError
Raised when identifier in expression is not defined for given
`DataId`.
TypeError
Raised when expression evaluates to a non-boolean type or when
operation in expression cannot be performed on operand types.
NotImplementedError
Raised when expression includes valid but unsupported syntax, e.g.
function call.
"""
visitor = _DataIdMatchTreeVisitor(dataId)
result = self.tree.visit(visitor)
if not isinstance(result, bool):
raise TypeError(f"Expression '{self.expression}' returned non-boolean object {type(result)}")
return result
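
A short usage sketch of DataIdMatch; the expression and data ID values are illustrative and use the daf_butler expression syntax handled by the visitor above:

from lsst.pipe.base.tests.mocks import DataIdMatch

matcher = DataIdMatch("visit > 100 AND detector IN (1..4)")
assert matcher.match({"visit": 101, "detector": 2})
assert not matcher.match({"visit": 42, "detector": 2})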