Skip to content

Commit

Permalink
SQUASH: more tests and minor fixes for PipelineGraph.
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Jun 22, 2023
1 parent 532bbcb commit b9b2dc2
Show file tree
Hide file tree
Showing 5 changed files with 384 additions and 123 deletions.
26 changes: 26 additions & 0 deletions python/lsst/pipe/base/pipeline_graph/_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class Edge(ABC):
Name of the dataset type's storage class as seen by the task.
connection_name : `str`
Internal name for the connection as seen by the task.
raw_dimensions : `frozenset` [ `str` ]
Raw dimensions from the connection definition.
"""

def __init__(
Expand All @@ -63,11 +65,13 @@ def __init__(
dataset_type_key: NodeKey,
storage_class_name: str,
connection_name: str,
raw_dimensions: frozenset[str],
):
self.task_key = task_key
self.dataset_type_key = dataset_type_key
self.connection_name = connection_name
self.storage_class_name = storage_class_name
self.raw_dimensions = raw_dimensions

INIT_TO_TASK_NAME: ClassVar[str] = "INIT"
"""Edge key for the special edge that connects a task init node to the
Expand All @@ -90,6 +94,15 @@ def __init__(
not the parent storage class.
"""

raw_dimensions: frozenset[str]
"""Raw dimensions in the task declaration.
This can only be used safely for partial comparisons: two edges with the
same ``raw_dimensions`` (and the same parent dataset type name) always have
the same resolved dimensions, but edges with different ``raw_dimensions``
may also have the same resolved dimensions.
"""

@property
def is_init(self) -> bool:
"""Whether this dataset is read or written when the task is
Expand Down Expand Up @@ -185,6 +198,13 @@ def diff(self: _S, other: _S, connection_type: str = "connection") -> list[str]:
f"{connection_type.capitalize()} {self.connection_name!r} has storage class "
f"{self.storage_class_name!r} in A, but {other.storage_class_name!r} in B."
)
if self.raw_dimensions != other.raw_dimensions:
result.append(
f"{connection_type.capitalize()} {self.connection_name!r} has raw dimensions "
f"{set(self.raw_dimensions)} in A, but {set(other.raw_dimensions)} in B "
"(differences in raw dimensions may not lead to differences in resolved dimensions, "
"but this cannot be checked without re-resolving the dataset type)."
)
return result

@abstractmethod
Expand Down Expand Up @@ -233,6 +253,8 @@ class ReadEdge(Edge):
Internal name for the connection as seen by the task.
component : `str` or `None`
Component of the dataset type requested by the task.
raw_dimensions : `frozenset` [ `str` ]
Raw dimensions from the connection definition.
Notes
-----
Expand Down Expand Up @@ -260,12 +282,14 @@ def __init__(
is_prerequisite: bool,
connection_name: str,
component: str | None,
raw_dimensions: frozenset[str],
):
super().__init__(
task_key=task_key,
dataset_type_key=dataset_type_key,
storage_class_name=storage_class_name,
connection_name=connection_name,
raw_dimensions=raw_dimensions,
)
self.is_prerequisite = is_prerequisite
self.component = component
Expand Down Expand Up @@ -351,6 +375,7 @@ def _from_connection_map(
storage_class_name=connection.storageClass,
is_prerequisite=is_prerequisite,
connection_name=connection_name,
raw_dimensions=frozenset(getattr(connection, "dimensions", frozenset())),
)

def _resolve_dataset_type(
Expand Down Expand Up @@ -585,6 +610,7 @@ def _from_connection_map(
dataset_type_key=NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name),
storage_class_name=connection.storageClass,
connection_name=connection_name,
raw_dimensions=frozenset(getattr(connection, "dimensions", frozenset())),
)

def _resolve_dataset_type(
Expand Down
88 changes: 41 additions & 47 deletions python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

__all__ = ("PipelineGraph",)

import copy
import gzip
import itertools
import json
Expand Down Expand Up @@ -51,7 +50,6 @@

if TYPE_CHECKING:
from ..config import PipelineTaskConfig
from ..configOverrides import ConfigOverrides
from ..connections import PipelineTaskConnections
from ..pipeline import TaskDef
from ..pipelineTask import PipelineTask
Expand Down Expand Up @@ -141,7 +139,6 @@ def _init_from_args(
`PipelineGraph` mutator methods provide strong exception safety (the
graph is left unchanged when an exception is raised and caught) unless
the exception raised is `PipelineGraphExceptionSafetyError`.
"""
self._xgraph = xgraph if xgraph is not None else networkx.MultiDiGraph()
self._sorted_keys: Sequence[NodeKey] | None
Expand Down Expand Up @@ -486,7 +483,7 @@ def _transform_xgraph_state(self, xgraph: _G, skip_edges: bool) -> _G:

def group_by_dimensions(
self, prerequisites: bool = False
) -> dict[DimensionGraph, tuple[list[TaskNode], list[DatasetTypeNode]]]:
) -> dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]]:
"""Group this graph's tasks and dataset types by their dimensions.
Parameters
Expand All @@ -501,8 +498,9 @@ def group_by_dimensions(
A dictionary of groups keyed by `DimensionGraph`, in which each
value is a tuple of:
- a `list` of `TaskNode` instances
- a `list` of `ResolvedDatasetTypeNode` instances
- a `dict` of `TaskNode` instances, keyed by task label
- a `dict` of `DatasetTypeNode` instances, keyed by
dataset type name.
that have those dimensions.
Expand All @@ -511,23 +509,23 @@ def group_by_dimensions(
Init inputs and outputs are always included, but always have empty
dimensions and hence are all grouped together.
"""
result: dict[DimensionGraph, tuple[list[TaskNode], list[DatasetTypeNode]]] = {}
next_new_value: tuple[list[TaskNode], list[DatasetTypeNode]] = ([], [])
result: dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]] = {}
next_new_value: tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]] = ({}, {})
for task_label, task_node in self.tasks.items():
if task_node.dimensions is None:
raise UnresolvedGraphError(f"Task with label {task_label!r} has not been resolved.")
if (group := result.setdefault(task_node.dimensions, next_new_value)) is next_new_value:
next_new_value = ([], []) # make new lists for next time
group[0].append(task_node)
next_new_value = ({}, {}) # make new lists for next time
group[0][task_node.label] = task_node
for dataset_type_name, dataset_type_node in self.dataset_types.items():
if dataset_type_node is None:
raise UnresolvedGraphError(f"Dataset type {dataset_type_name!r} has not been resolved.")
if not dataset_type_node.is_prerequisite or prerequisites:
if (
group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value)
) is next_new_value:
next_new_value = ([], []) # make new lists for next time
group[1].append(dataset_type_node)
next_new_value = ({}, {}) # make new lists for next time
group[1][dataset_type_node.name] = dataset_type_node
return result

@property
Expand Down Expand Up @@ -570,7 +568,9 @@ def sort(self) -> None:
if self._sorted_keys is None:
try:
sorted_keys: Sequence[NodeKey] = list(networkx.lexicographical_topological_sort(self._xgraph))
except networkx.NetworkXUnfeasible as err:
except networkx.NetworkXUnfeasible as err: # pragma: no cover
# Shouldn't be possible to get here, because we check for cycles
# when adding tasks, but we guard against it anyway.
cycle = networkx.find_cycle(self._xgraph)
raise PipelineDataCycleError(
f"Cycle detected while attempting to sort graph: {cycle}."
Expand Down Expand Up @@ -707,6 +707,7 @@ def add_task_nodes(self, nodes: Iterable[TaskNode]) -> None:
node_data: list[tuple[NodeKey, dict[str, Any]]] = []
edge_data: list[tuple[NodeKey, NodeKey, str, dict[str, Any]]] = []
for task_node in nodes:
task_node = task_node._resolved(self._universe)
node_data.append(
(task_node.key, {"instance": task_node, "bipartite": task_node.key.node_type.bipartite})
)
Expand Down Expand Up @@ -745,7 +746,9 @@ def add_task_nodes(self, nodes: Iterable[TaskNode]) -> None:
try:
self._xgraph.remove_edges_from(edge_data)
self._xgraph.remove_nodes_from(key for key, _ in node_data)
except Exception as err:
except Exception as err: # pragma: no cover
# There's no known way to get here, but we want to make it
# clear it's a big problem if we do.
raise PipelineGraphExceptionSafetyError(
"Error while attempting to revert PipelineGraph modification has left the graph in "
"an inconsistent state."
Expand All @@ -755,22 +758,20 @@ def add_task_nodes(self, nodes: Iterable[TaskNode]) -> None:

def reconfigure_tasks(
self,
*args: tuple[str, PipelineTaskConfig | ConfigOverrides],
*args: tuple[str, PipelineTaskConfig],
check_edges_unchanged: bool = False,
assume_edges_unchanged: bool = False,
**kwargs: PipelineTaskConfig | ConfigOverrides,
**kwargs: PipelineTaskConfig,
) -> None:
"""Update the configuration for one or more tasks.
Parameters
----------
*args : `tuple` [ `str`, `.PipelineTaskConfig` or `.ConfigOverrides` ]
Positional arguments are tuples of a task label and either a new
config object or a sequence of overrides to apply to a copy of the
current config. Note that the same arguments may also be passed as
``**kwargs``, which is usually more readable, but the same task may
be passed multiple times via ``*args`` and task labels in ``*args``
are not required to be valid Python identifiers.
*args : `tuple` [ `str`, `.PipelineTaskConfig` ]
Positional arguments are each a 2-tuple of task label and new
config object. Note that the same arguments may also be passed as
``**kwargs``, which is usually more readable, but task labels in
``*args`` are not required to be valid Python identifiers.
check_edges_unchanged : `bool`, optional
If `True`, require the edges (connections) of the modified tasks to
remain unchanged after the configuration updates, and verify that
Expand All @@ -779,42 +780,29 @@ def reconfigure_tasks(
If `True`, the caller declares that the edges (connections) of the
modified tasks will remain unchanged after the configuration
updates, and that it is unnecessary to check this.
**kwargs : `.PipelineTaskConfig` or `.ConfigOverrides`
**kwargs : `.PipelineTaskConfig`
New config objects or overrides to apply to copies of the current
config objects, with task labels as the keywords. Overrides
provided here are applied after those in ``*args``.
config objects, with task labels as the keywords.
Raises
------
ValueError
Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged``
are both `True`, or if a full config is provided for a task after
another full config or an override has already been provided.
are both `True`, or if the same task appears twice.
EdgesChangedError
Raised if ``check_edges_unchanged=True`` and the edges of a task do
change.
Notes
-----
If reconfiguring a task causes its edges to change, any dataset type
nodes linked to those edges will be reset to the unresolved state.
nodes connected to that task (not just those whose edges have changed!)
will be unresolved.
"""
from ..configOverrides import ConfigOverrides

new_configs: dict[str, PipelineTaskConfig] = {}
for task_label, config_update in itertools.chain(args, kwargs.items()):
if isinstance(config_update, ConfigOverrides):
if (new_config := new_configs.get(task_label)) is None:
new_config = copy.deepcopy(self.tasks[task_label].config)
new_configs[task_label] = new_config
config_update.applyTo(new_config)
else:
if new_configs.setdefault(task_label, config_update) is not config_update:
raise ValueError(
f"Full config for {task_label!r} provided in **kwargs after providing a "
"config or override in *args."
)

if new_configs.setdefault(task_label, config_update) is not config_update:
raise ValueError(f"Config for {task_label!r} provided more than once.")
updates = {
task_label: self.tasks[task_label]._reconfigured(
config, rebuild=not assume_edges_unchanged, universe=self._universe
Expand Down Expand Up @@ -914,7 +902,9 @@ def remove_tasks(
for subset_label in referencing_subsets:
self._task_subsets[subset_label].remove(task_node.label)
self._xgraph.remove_nodes_from(nodes_to_remove)
except Exception as err:
except Exception as err: # pragma: no cover
# There's no known way to get here, but we want to make it
# clear it's a big problem if we do.
raise PipelineGraphExceptionSafetyError(
"Error during task removal has left the graph in an inconsistent state."
) from err
Expand Down Expand Up @@ -1089,7 +1079,9 @@ def resolve(
try:
for node_key, node_value in updates.items():
self._xgraph.nodes[node_key]["instance"] = node_value
except Exception as err:
except Exception as err: # pragma: no cover
# There's no known way to get here, but we want to make it
# clear it's a big problem if we do.
raise PipelineGraphExceptionSafetyError(
"Error during dataset type resolution has left the graph in an inconsistent state."
) from err
Expand Down Expand Up @@ -1332,9 +1324,11 @@ def _replace_task_nodes(
for task_node in shallow.values():
self._xgraph.nodes[task_node.key]["instance"] = task_node
self._xgraph.nodes[task_node.init.key]["instance"] = task_node.init
except PipelineGraphExceptionSafetyError:
except PipelineGraphExceptionSafetyError: # pragma: no cover
raise
except Exception as err:
except Exception as err: # pragma: no cover
# There's no known way to get here, but we want to make it clear
# it's a big problem if we do.
raise PipelineGraphExceptionSafetyError(
"Error while replacing tasks has left the graph in an inconsistent state."
) from err
Expand Down
15 changes: 13 additions & 2 deletions python/lsst/pipe/base/pipeline_graph/_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,11 @@ def config(self) -> PipelineTaskConfig:
"""
return self.init.config

@property
def has_dimensions(self) -> bool:
"""Whether the `dimensions` attribute my be accessed."""
return self._dimensions is not None

@property
def dimensions(self) -> DimensionGraph:
"""Standardized dimensions of the task."""
Expand Down Expand Up @@ -736,7 +741,7 @@ def _reconfigured(
),
)

def _resolved(self, universe: DimensionUniverse) -> TaskNode:
def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
"""Return an otherwise-equivalent task node with resolved dimensions.
Parameters
Expand All @@ -752,6 +757,8 @@ def _resolved(self, universe: DimensionUniverse) -> TaskNode:
"""
if self._dimensions is not None and self._dimensions.universe is universe:
return self
if self._dimensions is None and universe is None:
return self
return TaskNode(
key=self.key,
init=self.init,
Expand All @@ -760,7 +767,11 @@ def _resolved(self, universe: DimensionUniverse) -> TaskNode:
outputs=self.outputs,
log_output=self.log_output,
metadata_output=self.metadata_output,
dimensions=universe.extract(self._get_imported_data().connections.dimensions),
dimensions=(
universe.extract(self._get_imported_data().connections.dimensions)
if universe is not None
else None
),
)

def _to_xgraph_state(self) -> dict[str, Any]:
Expand Down
9 changes: 8 additions & 1 deletion python/lsst/pipe/base/pipeline_graph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,16 @@ class SerializedEdge(pydantic.BaseModel):
storage_class: str
"""Name of the storage class."""

raw_dimensions: list[str]
"""Raw dimensions of the dataset type from the task connections."""

@classmethod
def serialize(cls, target: Edge) -> SerializedEdge:
"""Transform an `Edge` to a `SerializedEdge`."""
return SerializedEdge.construct(
storage_class=target.storage_class_name, dataset_type_name=target.dataset_type_name
storage_class=target.storage_class_name,
dataset_type_name=target.dataset_type_name,
raw_dimensions=sorted(target.raw_dimensions),
)

def deserialize_read_edge(
Expand All @@ -115,6 +120,7 @@ def deserialize_read_edge(
is_prerequisite=is_prerequisite,
component=component,
connection_name=connection_name,
raw_dimensions=frozenset(self.raw_dimensions),
)

def deserialize_write_edge(
Expand All @@ -129,6 +135,7 @@ def deserialize_write_edge(
dataset_type_key=dataset_type_key,
storage_class_name=self.storage_class,
connection_name=connection_name,
raw_dimensions=frozenset(self.raw_dimensions),
)


Expand Down
Loading

0 comments on commit b9b2dc2

Please sign in to comment.