improve docstrings, type hints in data_manager
jacobwhall committed Jul 18, 2024
1 parent 66ea598 commit 9a2749f
Showing 2 changed files with 141 additions and 26 deletions.
52 changes: 51 additions & 1 deletion data_manager/src/data_manager/configuration.py
@@ -1,6 +1,5 @@
import logging
import tomllib
from configparser import ConfigParser
from pathlib import Path
from typing import Literal, Optional, Union

@@ -13,6 +12,7 @@ class RunParameters(BaseModel):
parameters for a Dataset. This model is consumed by
Dataset.run() as settings for how to run the Dataset.
"""

backend: Literal["local", "mpi", "prefect"] = "prefect"
task_runner: Literal[
"concurrent",
@@ -21,17 +21,58 @@ class RunParameters(BaseModel):
"kubernetes",
"sequential",
] = "concurrent"
"""
The backend to run the dataset on.
Most common values are "sequential", and "concurrent"
"""
run_parallel: bool = True
"""
Whether or not to run the Dataset in parallel.
"""
max_workers: Optional[int] = 4
"""
Maximum number of concurrent tasks that may be run for this Dataset.
This may be overridden when calling `Dataset.run_tasks()`.
"""
bypass_error_wrapper: bool = False
"""
If set to `True`, exceptions will not be caught when running tasks, and will instead stop execution of the entire dataset.
This can be helpful for quickly debugging a dataset, especially when it is running sequentially.
"""
threads_per_worker: Optional[int] = 1
"""
Passed through as `threads_per_worker` to the Dask cluster when using the Dask task runner.
"""
# cores_per_process: Optional[int] = None
chunksize: int = 1
"""
Sets the chunksize for pools created for concurrent or MPI task runners.
"""
log_dir: str
"""
Path to directory where logs for this Dataset run should be saved.
This is the only run parameter without a default, so it must be set in a Dataset's configuration file.
"""
logger_level: int = logging.INFO
"""
Minimum log level to record.
For more information, see the [relevant Python documentation](https://docs.python.org/3/library/logging.html#logging-levels).
"""
retries: int = 3
"""
Number of times to retry each task before giving up.
This parameter can be overridden per task run when calling `Dataset.run_tasks()`.
"""
retry_delay: int = 5
"""
Time in seconds to wait between task retries.
This parameter can be overridden per task run when calling `Dataset.run_tasks()`.
"""
conda_env: str = "geodata38"
"""
Conda environment to use when running the dataset.
**Deprecated because we do not use this in the new Prefect/Kubernetes setup.**
"""


class BaseDatasetConfiguration(BaseModel):
@@ -42,7 +83,12 @@ class BaseDatasetConfiguration(BaseModel):
Common examples are `overwrite_download`,
`overwrite_processing`, or `year_list`.
"""

run: RunParameters
"""
A `RunParameters` model that defines how this Dataset should be run.
This is passed into the `Dataset.run()` function.
"""


def get_config(
@@ -57,6 +103,10 @@ def get_config(
returns a `BaseDatasetConfiguration` model
filled in with the values from that
configuration file.
Parameters:
model: The model to load the configuration values into. This should nearly always be a Dataset-specific model defined in `main.py` that inherits `BaseDatasetConfiguration`.
config_path: The relative path to the TOML configuration file. It's unlikely this parameter should ever be changed from its default.
"""
config_path = Path(config_path)
if config_path.exists():
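
Here is a hedged sketch of typical `get_config` usage; the subclass name and its extra fields are hypothetical, and the configuration file is assumed to exist at the default path:

```python
from data_manager.configuration import BaseDatasetConfiguration, get_config

# hypothetical Dataset-specific configuration model, as would be
# defined in a dataset's main.py
class MyDatasetConfiguration(BaseDatasetConfiguration):
    overwrite_download: bool = False
    year_list: list[int] = []

# loads values from the TOML configuration file at the default path
config = get_config(MyDatasetConfiguration)
print(config.run.max_workers)
```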
115 changes: 90 additions & 25 deletions data_manager/src/data_manager/dataset.py
@@ -31,7 +31,10 @@ class ResultTuple(Sequence):
"""
This is an immutable sequence designed to hold TaskResults.
It also keeps track of the name of a run and the time it started.
ResultTuple.results() returns a list of results from each task
ResultTuple.results() returns a list of results from each task.
Inherits the `Sequence` class, and therefore provides methods
such as `__len__` and `__getitem__`.
"""

def __init__(
@@ -40,6 +43,12 @@ def __init__(
name: str,
timestamp: datetime = datetime.today(),
):
"""
Parameters:
iterable: Iterable of `TaskResult`s to store.
name: Name of this `ResultTuple`.
timestamp: Timestamp of the task run that produced these results.
"""
self.elements = []
for value in iterable:
if isinstance(value, TaskResult):
@@ -54,10 +63,13 @@ def __init__(
def __getitem__(self, key: int):
return self.elements[key]

def __len__(self):
def __len__(self) -> int:
"""
Returns the number of results in this `ResultTuple`.
"""
return len(self.elements)

def __repr__(self):
def __repr__(self) -> str:
success_count = sum(1 for t in self.elements if t.status_code == 0)
error_count = len(self.elements) - success_count
return f'<ResultTuple named "{self.name}" with {success_count} successes, {error_count} errors>'
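
A short usage sketch follows; the `TaskResult` import path and positional field order are assumptions inferred from `error_wrapper` below:

```python
from datetime import datetime

# import path assumed; TaskResult's positional fields are inferred from
# the usage TaskResult(1, repr(e), args, None) in error_wrapper below
from data_manager.dataset import ResultTuple, TaskResult

ok = TaskResult(0, "Success", {"year": 2020}, "cog/2020.tif")
err = TaskResult(1, "ValueError('bad input')", {"year": 2021}, None)

results = ResultTuple([ok, err], name="process", timestamp=datetime.today())
print(len(results))  # 2, via __len__
print(results[0])    # the first TaskResult, via __getitem__
print(results)       # <ResultTuple named "process" with 1 successes, 1 errors>
```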
@@ -81,7 +93,7 @@ def results(self):

class Dataset(ABC):
"""
This is the base class for Datasets, providing functions that manage task runs and logs
This is the base class for Datasets, providing functions that manage task runs and logs.
"""

backend: str
@@ -119,6 +131,31 @@ def tmp_to_dst_file(
tmp_dir: Optional[str | os.PathLike] = None,
validate_cog: bool = False,
):
"""
Context manager that provides a temporary file path to write
output files to; the file is automatically moved to its final
destination once the context is exited. This prevents interrupted
jobs from leaving partially-written files in the filesystem where
they might be mistaken for complete files.
Additionally, this context manager can create output directories
that don't exist yet, or validate COG files after they've been
written. See the list of parameters below for more information.
Here is an example of its use:
```python
with self.tmp_to_dst_file(final_dst, validate_cog=True) as tmp_dst:
with rasterio.open(tmp_dst, "w", **meta) as dst:
...
```
Parameters:
final_dst: Path to where the file should be written.
make_dst_dir: If set to `True`, the parent directory of `final_dst` will be created (along with any of its parents, as necessary).
tmp_dir: Path to directory where file should be temporarily stored. If set to `None`, a default directory will be used.
validate_cog: If set to `True`, the written file will be validated as a COG, and an exception will be raised if this validation fails.
"""
logger = self.get_logger()

final_dst = Path(final_dst)
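
To complement the docstring example above, here is a hedged sketch of non-raster usage with `make_dst_dir`, assuming it runs inside a Dataset method:

```python
# a sketch: write a plain-text output atomically, creating the output
# directory first if it does not exist yet (the path is hypothetical)
final_dst = "output/reports/summary.txt"
with self.tmp_to_dst_file(final_dst, make_dst_dir=True) as tmp_dst:
    with open(tmp_dst, "w") as f:
        f.write("all tasks completed\n")
# on exiting the context, the temporary file is moved into place at final_dst
```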
@@ -193,7 +230,9 @@ def error_wrapper(self, func: Callable, args: Dict[str, Any]):
logger.error(f"Task failed with exception (giving up): {repr(e)}")
return TaskResult(1, repr(e), args, None)

def run_serial_tasks(self, name, func: Callable, input_list: Iterable[Any]):
def run_serial_tasks(
self, name, func: Callable, input_list: Iterable[Dict[str, Any]]
):
"""
Run tasks in serial (locally), given a function and a list of inputs.
This will always return a list of TaskResults!
@@ -206,7 +245,7 @@ def run_concurrent_tasks(
self,
name: str,
func: Callable,
input_list: Iterable[Any],
input_list: Iterable[Dict[str, Any]],
force_sequential: bool,
max_workers: int = None,
):
@@ -228,7 +267,7 @@ def run_prefect_tasks(
self,
name: str,
func: Callable,
input_list: Iterable[Any],
input_list: Iterable[Dict[str, Any]],
force_sequential: bool,
prefect_concurrency_tag: str = None,
prefect_concurrency_task_value: int = 1,
@@ -245,7 +284,9 @@ def run_prefect_tasks(

def cfunc(wrapper_args, func_args):
func, prefect_concurrency_tag, prefect_concurrency_task_value = wrapper_args
with concurrency(prefect_concurrency_tag, occupy=prefect_concurrency_task_value):
with concurrency(
prefect_concurrency_tag, occupy=prefect_concurrency_task_value
):
return func(*func_args)

if not prefect_concurrency_tag:
@@ -263,16 +304,21 @@ def cfunc(wrapper_args, func_args):
retries=self.retries,
retry_delay_seconds=self.retry_delay,
persist_result=True,
)
)

futures = []
for i in input_list:
w = [f[1] for f in futures] if force_sequential else None
if prefect_concurrency_tag:
args = ((func, prefect_concurrency_tag, prefect_concurrency_task_value), i)
args = (
(func, prefect_concurrency_tag, prefect_concurrency_task_value),
i,
)
else:
args = i
futures.append((args, task_wrapper.submit(*args, wait_for=w, return_state=False)))
futures.append(
(args, task_wrapper.submit(*args, wait_for=w, return_state=False))
)

results = []

@@ -341,7 +387,7 @@ def run_mpi_tasks(
self,
name: str,
func: Callable,
input_list: Iterable[Any],
input_list: Iterable[Dict[str, Any]],
force_sequential: bool,
max_workers: int = None,
):
@@ -355,9 +401,7 @@ def run_mpi_tasks(
if not max_workers:
max_workers = self.mpi_max_workers

with MPIPoolExecutor(
max_workers=max_workers, chunksize=self.chunksize
) as pool:
with MPIPoolExecutor(max_workers=max_workers, chunksize=self.chunksize) as pool:
futures = []
for i in input_list:
f = pool.submit(self.error_wrapper, func, i)
@@ -369,21 +413,32 @@ def run_tasks(
def run_tasks(
self,
func: Callable,
input_list: Iterable[Any],
allow_futures: bool = True,
input_list: Iterable[Dict[str, Any]],
name: Optional[str] = None,
retries: int = 3,
retry_delay: int = 60,
force_sequential: bool = False,
force_serial: bool = False,
max_workers: Optional[int] = None,
prefect_concurrency_tag: str = None,
prefect_concurrency_task_value: int = None,
):
prefect_concurrency_tag: Optional[str] = None,
prefect_concurrency_task_value: Optional[int] = None,
) -> ResultTuple:
"""
Run a bunch of tasks, calling one of the task-running functions above.
This is the function that should be called most often from self.main().
It will return a ResultTuple of TaskResults.
Parameters:
func: The function to run for each task.
input_list: An iterable of function inputs. For each input, a new task will be created with that input passed as the only parameter.
name: A name for this task run, for easier reference.
retries: Number of times to retry a task before giving up.
retry_delay: Delay (in seconds) to wait between task retries.
force_sequential: If set to `True`, all tasks in this run will be run in sequence, regardless of backend.
force_serial: If set to `True`, all tasks will be run locally (using the internal "serial runner") rather than with this Dataset's usual backend. **Please avoid using this parameter; it will likely be deprecated soon!**
max_workers: Maximum number of tasks to run at once, if using a concurrent mode. This value will not override `force_sequential` or `force_serial`. **Warning: This is not yet supported by the Prefect backend. We hope to fix this soon.**
prefect_concurrency_tag: If using the Prefect backend, this tag will be used to limit the concurrency of this task. **This will eventually be deprecated in favor of `max_workers` once we have implemented that for the Prefect backend.**
prefect_concurrency_task_value: If using the Prefect backend, this sets the maximum number of tasks to run at once, similar to `max_workers`. **See warning above.**
"""

timestamp = datetime.today()
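
As a sketch of how `run_tasks` is typically invoked from `main()` (the task function, its inputs, and the log message are hypothetical):

```python
# a sketch of a Dataset subclass's main(); download_year is a hypothetical
# task function that accepts a single input, per the docstring above
def main(self):
    inputs = [{"year": y} for y in (2019, 2020, 2021)]
    results = self.run_tasks(
        self.download_year,
        inputs,
        name="download",
        retries=2,
        retry_delay=30,
    )
    failed = [r for r in results if r.status_code != 0]
    self.get_logger().info(f"{len(results) - len(failed)} of {len(results)} tasks succeeded")
```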
@@ -408,7 +463,6 @@ def run_tasks(
elif not isinstance(name, str):
raise TypeError("Name of task run must be a string")


if max_workers is None and hasattr(self, "max_workers"):
max_workers = self.max_workers

@@ -419,10 +473,19 @@ def run_tasks(
name, func, input_list, force_sequential, max_workers=max_workers
)
elif self.backend == "prefect":
results = self.run_prefect_tasks(name, func, input_list, force_sequential, prefect_concurrency_tag, prefect_concurrency_task_value)
results = self.run_prefect_tasks(
name,
func,
input_list,
force_sequential,
prefect_concurrency_tag,
prefect_concurrency_task_value,
)

elif self.backend == "mpi":
results = self.run_mpi_tasks(name, func, input_list, force_sequential, max_workers=max_workers)
results = self.run_mpi_tasks(
name, func, input_list, force_sequential, max_workers=max_workers
)
else:
raise ValueError(
"Requested backend not recognized. Have you called this Dataset's run function?"
@@ -626,8 +689,10 @@ def run(
self.backend = "prefect"

from prefect import flow
from prefect.task_runners import SequentialTaskRunner, ConcurrentTaskRunner#, ThreadPoolTaskRunner

from prefect.task_runners import ( # , ThreadPoolTaskRunner
ConcurrentTaskRunner,
SequentialTaskRunner,
)

if params.task_runner == "sequential":
tr = SequentialTaskRunner
