Skip to content

Commit

Permalink
Merge pull request #369 from materialsproject/flow-job-magic-methods
Browse files Browse the repository at this point in the history
`Flow` + `Job` magic methods
  • Loading branch information
utf authored Aug 13, 2023
2 parents 31b8954 + 781d3b5 commit 10aad48
Show file tree
Hide file tree
Showing 8 changed files with 382 additions and 28 deletions.
10 changes: 10 additions & 0 deletions docs/tutorials/8-fireworks.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ flow.update_config({"manager_config": {"_fworker": "fworker1"}}, name_filter="jo
flow.update_config({"manager_config": {"_fworker": "fworker2"}}, name_filter="job2")
```

NB: There are two ways to iterate over a `Flow`. The `iterflow` method iterates through a flow such that root nodes of the graph are always returned first. This has the benefit that the `job.output` references can always be resolved.
`Flow` also has an `__iter__` method, meaning you can write

```py
for job_or_subflow in flow:
...
```

to simply iterate through the `Flow.jobs` array. Note that `jobs` can also contain other flows.

### Launching the Jobs

As described above, convert the flow to a workflow via {obj}`flow_to_workflow` and add it to your launch pad.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ show_missing = true
exclude_lines = [
'^\s*@overload( |$)',
'^\s*assert False(,|$)',
'if TYPE_CHECKING:',
'if typing.TYPE_CHECKING:',
]

Expand Down
120 changes: 111 additions & 9 deletions src/jobflow/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@

from __future__ import annotations

import copy
import logging
import typing
import warnings
from typing import TYPE_CHECKING, Sequence

from monty.json import MSONable

import jobflow
from jobflow.core.reference import find_and_get_references
from jobflow.utils import ValueEnum, contains_flow_or_job, suuid

if typing.TYPE_CHECKING:
from typing import Any, Callable
if TYPE_CHECKING:
from typing import Any, Callable, Iterator

from networkx import DiGraph

import jobflow
from jobflow import Job

__all__ = ["JobOrder", "Flow", "get_flow"]

Expand Down Expand Up @@ -144,8 +146,94 @@ def __init__(
self.add_jobs(jobs)
self.output = output

def __len__(self) -> int:
"""Get the number of jobs or subflows in the flow."""
return len(self.jobs)

def __getitem__(self, idx: int | slice) -> Flow | Job | tuple[Flow | Job, ...]:
"""Get the job(s) or subflow(s) at the given index/slice."""
return self.jobs[idx]

def __setitem__(
self, idx: int | slice, value: Flow | Job | Sequence[Flow | Job]
) -> None:
"""Set the job(s) or subflow(s) at the given index/slice."""
if (
not isinstance(value, (Flow, jobflow.Job, tuple, list))
or isinstance(value, (tuple, list))
and not all(isinstance(v, (Flow, jobflow.Job)) for v in value)
):
raise TypeError(
f"Flow can only contain Job or Flow objects, not {type(value).__name__}"
)
jobs = list(self.jobs)
jobs[idx] = value # type: ignore[index, assignment]
self.jobs = tuple(jobs)

def __iter__(self) -> Iterator[Flow | Job]:
"""Iterate through the jobs in the flow."""
return iter(self.jobs)

def __contains__(self, item: Flow | Job) -> bool:
"""Check if the flow contains a job or subflow."""
return item in self.jobs

def __add__(self, other: Job | Flow | Sequence[Flow | Job]) -> Flow:
"""Add a job or subflow to the flow."""
if not isinstance(other, (Flow, jobflow.Job, tuple, list)):
return NotImplemented
new_flow = self.__deepcopy__()
new_flow.add_jobs(other)
return new_flow

def __sub__(self, other: Flow | Job) -> Flow:
"""Remove a job or subflow from the flow."""
if other not in self.jobs:
raise ValueError(f"{other!r} not found in flow")
new_flow = self.__deepcopy__()
new_flow.jobs = tuple([job for job in new_flow.jobs if job != other])
return new_flow

def __repr__(self, level=0, index=None) -> str:
"""Get a string representation of the flow."""
indent = " " * level
name, uuid = self.name, self.uuid
flow_index = f"{index}." if index is not None else ""
job_reprs = "\n".join(
f"{indent}{flow_index}{i}. "
f"{j.__repr__(level + 1, f'{flow_index}{i}') if isinstance(j, Flow) else j}"
for i, j in enumerate(self.jobs, 1)
)
return f"Flow({name=}, {uuid=})\n{job_reprs}"

def __eq__(self, other: object) -> bool:
"""Check if the flow is equal to another flow."""
if not isinstance(other, Flow):
return NotImplemented
return self.uuid == other.uuid

def __hash__(self) -> int:
"""Get the hash of the flow."""
return hash(self.uuid)

def __deepcopy__(self, memo: dict[int, Any] = None) -> Flow:
"""Get a deep copy of the flow.
Shallow copy doesn't make sense; jobs aren't allowed to belong to multiple flows
"""
kwds = self.as_dict()
for key in ("jobs", "@class", "@module", "@version"):
kwds.pop(key)
jobs = copy.deepcopy(self.jobs, memo)
new_flow = Flow(jobs=[], **kwds)
# reassign host
for job in jobs:
job.hosts = [new_flow.uuid]
new_flow.jobs = jobs
return new_flow

@property
def jobs(self) -> tuple[Flow | jobflow.Job, ...]:
def jobs(self) -> tuple[Flow | Job, ...]:
"""
Get the Jobs in the Flow.
Expand All @@ -156,6 +244,20 @@ def jobs(self) -> tuple[Flow | jobflow.Job, ...]:
"""
return self._jobs

@jobs.setter
def jobs(self, jobs: Sequence[Flow | Job] | Job | Flow):
"""
Set the Jobs in the Flow.
Parameters
----------
jobs
The list of Jobs/Flows of the Flow.
"""
if isinstance(jobs, (Flow, jobflow.Job)):
jobs = [jobs]
self._jobs = tuple(jobs)

@property
def output(self) -> Any:
"""
Expand Down Expand Up @@ -666,7 +768,7 @@ def add_hosts_uuids(
for j in self.jobs:
j.add_hosts_uuids(hosts_uuids, prepend=prepend)

def add_jobs(self, jobs: list[Flow | jobflow.Job] | jobflow.Job | Flow):
def add_jobs(self, jobs: Job | Flow | Sequence[Flow | Job]) -> None:
"""
Add Jobs or Flows to the Flow.
Expand All @@ -679,14 +781,14 @@ def add_jobs(self, jobs: list[Flow | jobflow.Job] | jobflow.Job | Flow):
A list of Jobs and Flows.
"""
if not isinstance(jobs, (tuple, list)):
jobs = [jobs]
jobs = [jobs] # type: ignore[list-item]

job_ids = set(self.all_uuids)
hosts = [self.uuid, *self.hosts]
for job in jobs:
if job.host is not None and job.host != self.uuid:
raise ValueError(
f"{job.__class__.__name__} {job.name} ({job.uuid}) already belongs "
f"{type(job).__name__} {job.name} ({job.uuid}) already belongs "
f"to another flow."
)
if job.uuid in job_ids:
Expand Down Expand Up @@ -743,7 +845,7 @@ def remove_jobs(self, indices: int | list[int]):


def get_flow(
flow: Flow | jobflow.Job | list[jobflow.Job],
flow: Flow | Job | list[jobflow.Job],
) -> Flow:
"""
Check dependencies and return flow object.
Expand Down
56 changes: 48 additions & 8 deletions src/jobflow/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from jobflow.utils.uuid import suuid

if typing.TYPE_CHECKING:
from typing import Any, Callable, Hashable
from typing import Any, Callable, Hashable, Sequence

from networkx import DiGraph
from pydantic import BaseModel
Expand Down Expand Up @@ -194,10 +194,7 @@ def get_job(*args, **kwargs) -> Job:
args = args[1:]

return Job(
function=f,
function_args=args,
function_kwargs=kwargs,
**job_kwargs,
function=f, function_args=args, function_kwargs=kwargs, **job_kwargs
)

get_job.original = func
Expand Down Expand Up @@ -366,6 +363,49 @@ def __init__(
f"inputs to your Job."
)

def __repr__(self):
"""Get a string representation of the job."""
name, uuid = self.name, self.uuid
return f"Job({name=}, {uuid=})"

def __contains__(self, item: Hashable) -> bool:
"""
Check if the job contains a reference to a given UUID.
Parameters
----------
item
A UUID.
Returns
-------
bool
Whether the job contains a reference to the UUID.
"""
return item in self.input_uuids

def __eq__(self, other: object) -> bool:
"""
Check if two jobs are equal.
Parameters
----------
other
Another job.
Returns
-------
bool
Whether the jobs are equal.
"""
if not isinstance(other, Job):
return NotImplemented
return self.__dict__ == other.__dict__

def __hash__(self) -> int:
"""Get the hash of the job."""
return hash(self.uuid)

@property
def input_references(self) -> tuple[jobflow.OutputReference, ...]:
"""
Expand Down Expand Up @@ -474,7 +514,7 @@ def host(self):
"""
return self.hosts[0] if self.hosts else None

def set_uuid(self, uuid: str):
def set_uuid(self, uuid: str) -> None:
"""
Set the UUID of the job.
Expand Down Expand Up @@ -1079,7 +1119,7 @@ def __setattr__(self, key, value):
else:
super().__setattr__(key, value)

def add_hosts_uuids(self, hosts_uuids: str | list[str], prepend: bool = False):
def add_hosts_uuids(self, hosts_uuids: str | Sequence[str], prepend: bool = False):
"""
Add a list of UUIDs to the internal list of hosts.
Expand All @@ -1095,7 +1135,7 @@ def add_hosts_uuids(self, hosts_uuids: str | list[str], prepend: bool = False):
Insert the UUIDs at the beginning of the list rather than extending it.
"""
if not isinstance(hosts_uuids, (list, tuple)):
hosts_uuids = [hosts_uuids]
hosts_uuids = [hosts_uuids] # type: ignore
if prepend:
self.hosts[0:0] = hosts_uuids
else:
Expand Down
2 changes: 1 addition & 1 deletion src/jobflow/core/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def as_dict(self):
schema_dict = MontyEncoder().default(schema) if schema is not None else None
data = {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
"@class": type(self).__name__,
"@version": None,
"uuid": self.uuid,
"attributes": self.attributes,
Expand Down
Loading

0 comments on commit 10aad48

Please sign in to comment.