From 9a2749fb025467b2b532c491ad8b54c13190895b Mon Sep 17 00:00:00 2001
From: Jacob Hall
Date: Thu, 18 Jul 2024 17:21:01 -0400
Subject: [PATCH] improve docstrings, type hints in data_manager

---
 .../src/data_manager/configuration.py    |  52 +++++++-
 data_manager/src/data_manager/dataset.py | 115 ++++++++++++++----
 2 files changed, 141 insertions(+), 26 deletions(-)

diff --git a/data_manager/src/data_manager/configuration.py b/data_manager/src/data_manager/configuration.py
index 52afee9..e3d5c06 100644
--- a/data_manager/src/data_manager/configuration.py
+++ b/data_manager/src/data_manager/configuration.py
@@ -1,6 +1,5 @@
 import logging
 import tomllib
-from configparser import ConfigParser
 from pathlib import Path
 from typing import Literal, Optional, Union
 
@@ -13,6 +12,7 @@ class RunParameters(BaseModel):
     parameters for a Dataset. This model is consumed by Dataset.run() as
     settings for how to run the Dataset.
     """
+
     backend: Literal["local", "mpi", "prefect"] = "prefect"
     task_runner: Literal[
         "concurrent",
@@ -21,17 +21,58 @@
         "kubernetes",
         "sequential",
     ] = "concurrent"
+    """
+    The task runner to use when running the Dataset.
+    The most common values are "sequential" and "concurrent".
+    """
     run_parallel: bool = True
+    """
+    Whether to run the Dataset in parallel.
+    """
     max_workers: Optional[int] = 4
+    """
+    Maximum number of concurrent tasks that may be run for this Dataset.
+    This may be overridden when calling `Dataset.run_tasks()`.
+    """
     bypass_error_wrapper: bool = False
+    """
+    If set to `True`, exceptions will not be caught when running tasks, and will instead stop execution of the entire dataset.
+    This can be helpful for quickly debugging a dataset, especially when it is running sequentially.
+    """
     threads_per_worker: Optional[int] = 1
+    """
+    `threads_per_worker` value passed through to the DaskCluster when using the dask task runner.
+    """
     # cores_per_process: Optional[int] = None
     chunksize: int = 1
+    """
+    Sets the chunksize for pools created for concurrent or MPI task runners.
+    """
     log_dir: str
+    """
+    Path to the directory where logs for this Dataset run should be saved.
+    This is the only run parameter without a default, so it must be set in a Dataset's configuration file.
+    """
     logger_level: int = logging.INFO
+    """
+    Minimum level of messages to log.
+    For more information, see the [relevant Python documentation](https://docs.python.org/3/library/logging.html#logging-levels).
+    """
     retries: int = 3
+    """
+    Number of times to retry each task before giving up.
+    This parameter can be overridden per task run when calling `Dataset.run_tasks()`.
+    """
     retry_delay: int = 5
+    """
+    Time in seconds to wait between task retries.
+    This parameter can be overridden per task run when calling `Dataset.run_tasks()`.
+    """
     conda_env: str = "geodata38"
+    """
+    Conda environment to use when running the dataset.
+    **Deprecated because we do not use this in the new Prefect/Kubernetes setup**
+    """
 
 
 class BaseDatasetConfiguration(BaseModel):
@@ -42,7 +83,12 @@
     Common examples are `overwrite_download`, `overwrite_processing`,
     or `year_list`.
     """
+
     run: RunParameters
+    """
+    A `RunParameters` model that defines how this Dataset should be run.
+    This is passed into the `Dataset.run()` function.
+    """
 
 
 def get_config(
@@ -57,6 +103,10 @@
     returns a `BaseDatasetConfiguration` model filled in with the values
     from that configuration file.
+
+    Parameters:
+        model: The model to load the configuration values into. This should nearly always be a Dataset-specific model defined in `main.py` that inherits `BaseDatasetConfiguration`.
+        config_path: The relative path to the TOML configuration file. It's unlikely this parameter should ever be changed from its default.
     """
     config_path = Path(config_path)
     if config_path.exists():

diff --git a/data_manager/src/data_manager/dataset.py b/data_manager/src/data_manager/dataset.py
index 383a68e..b8af7f0 100644
--- a/data_manager/src/data_manager/dataset.py
+++ b/data_manager/src/data_manager/dataset.py
@@ -31,7 +31,10 @@ class ResultTuple(Sequence):
     """
     This is an immutable sequence designed to hold TaskResults
     It also keeps track of the name of a run and the time it started
-    ResultTuple.results() returns a list of results from each task
+    ResultTuple.results() returns a list of results from each task.
+
+    Inherits the `Sequence` class and therefore provides methods
+    such as `__len__` and `__getitem__`.
     """
 
     def __init__(
@@ -40,6 +43,12 @@
         name: str,
         timestamp: datetime = datetime.today(),
     ):
+        """
+        Parameters:
+            iterable: Iterable of `TaskResult`s to store.
+            name: Name of this `ResultTuple`.
+            timestamp: Timestamp of the task run that produced these results.
+        """
         self.elements = []
         for value in iterable:
             if isinstance(value, TaskResult):
@@ -54,10 +63,13 @@ def __init__(
     def __getitem__(self, key: int):
         return self.elements[key]
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Returns the number of results in this `ResultTuple`.
+        """
        return len(self.elements)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         success_count = sum(1 for t in self.elements if t.status_code == 0)
         error_count = len(self.elements) - success_count
         return f''
@@ -81,7 +93,7 @@ def results(self):
 
 class Dataset(ABC):
     """
-    This is the base class for Datasets, providing functions that manage task runs and logs
+    This is the base class for Datasets, providing functions that manage task runs and logs.
     """
 
     backend: str
@@ -119,6 +131,31 @@ def tmp_to_dst_file(
         tmp_dir: Optional[str | os.PathLike] = None,
         validate_cog: bool = False,
     ):
+        """
+        Context manager that provides a temporary file path to write
+        output to; when the context exits, the file is automatically
+        moved to its final destination. This prevents interrupted jobs
+        from leaving partially-written files in the filesystem where
+        they might be mistaken for complete files.
+
+        Additionally, this context manager can create output directories
+        that don't exist yet or validate COG files after they've been
+        written. See the list of parameters below for more information.
+
+        Here is an example of its use:
+
+        ```python
+        with self.tmp_to_dst_file(final_dst, validate_cog=True) as tmp_dst:
+            with rasterio.open(tmp_dst, "w", **meta) as dst:
+                ...
+        ```
+
+        Parameters:
+            final_dst: Path to where the file should be written.
+            make_dst_dir: If set to `True`, the parent directory of `final_dst` will be created (along with any of its parents, as necessary).
+            tmp_dir: Path to the directory where the file should be temporarily stored. If set to `None`, a default directory will be used.
+            validate_cog: If set to `True`, the written file will be validated as a COG, and an exception will be raised if this validation fails.
+        """
         logger = self.get_logger()
 
         final_dst = Path(final_dst)
@@ -193,7 +230,9 @@ def error_wrapper(self, func: Callable, args: Dict[str, Any]):
                 logger.error(f"Task failed with exception (giving up): {repr(e)}")
                 return TaskResult(1, repr(e), args, None)
 
-    def run_serial_tasks(self, name, func: Callable, input_list: Iterable[Any]):
+    def run_serial_tasks(
+        self, name, func: Callable, input_list: Iterable[Dict[str, Any]]
+    ):
         """
         Run tasks in serial (locally), given a function and list of inputs
         This will always return a list of TaskResults!
@@ -206,7 +245,7 @@ def run_concurrent_tasks(
         self,
         name: str,
         func: Callable,
-        input_list: Iterable[Any],
+        input_list: Iterable[Dict[str, Any]],
         force_sequential: bool,
         max_workers: int = None,
     ):
@@ -228,7 +267,7 @@ def run_prefect_tasks(
         self,
         name: str,
         func: Callable,
-        input_list: Iterable[Any],
+        input_list: Iterable[Dict[str, Any]],
         force_sequential: bool,
         prefect_concurrency_tag: str = None,
         prefect_concurrency_task_value: int = 1,
@@ -245,7 +284,9 @@ def cfunc(wrapper_args, func_args):
             func, prefect_concurrency_tag, prefect_concurrency_task_value = wrapper_args
-            with concurrency(prefect_concurrency_tag, occupy=prefect_concurrency_task_value):
+            with concurrency(
+                prefect_concurrency_tag, occupy=prefect_concurrency_task_value
+            ):
                 return func(*func_args)
 
         if not prefect_concurrency_tag:
@@ -263,16 +304,21 @@
             retries=self.retries,
             retry_delay_seconds=self.retry_delay,
             persist_result=True,
-            )
+        )
 
         futures = []
 
         for i in input_list:
             w = [f[1] for f in futures] if force_sequential else None
             if prefect_concurrency_tag:
-                args = ((func, prefect_concurrency_tag, prefect_concurrency_task_value), i)
+                args = (
+                    (func, prefect_concurrency_tag, prefect_concurrency_task_value),
+                    i,
+                )
             else:
                 args = i
-            futures.append((args, task_wrapper.submit(*args, wait_for=w, return_state=False)))
+            futures.append(
+                (args, task_wrapper.submit(*args, wait_for=w, return_state=False))
+            )
 
         results = []
@@ -341,7 +387,7 @@ def run_mpi_tasks(
         self,
         name: str,
         func: Callable,
-        input_list: Iterable[Any],
+        input_list: Iterable[Dict[str, Any]],
         force_sequential: bool,
         max_workers: int = None,
     ):
@@ -355,9 +401,7 @@
         if not max_workers:
             max_workers = self.mpi_max_workers
 
-        with MPIPoolExecutor(
-            max_workers=max_workers, chunksize=self.chunksize
-        ) as pool:
+        with MPIPoolExecutor(max_workers=max_workers, chunksize=self.chunksize) as pool:
             futures = []
             for i in input_list:
                 f = pool.submit(self.error_wrapper, func, i)
@@ -369,21 +413,32 @@
     def run_tasks(
         self,
         func: Callable,
-        input_list: Iterable[Any],
-        allow_futures: bool = True,
+        input_list: Iterable[Dict[str, Any]],
         name: Optional[str] = None,
         retries: int = 3,
         retry_delay: int = 60,
         force_sequential: bool = False,
         force_serial: bool = False,
         max_workers: Optional[int] = None,
-        prefect_concurrency_tag: str = None,
-        prefect_concurrency_task_value: int = None,
-    ):
+        prefect_concurrency_tag: Optional[str] = None,
+        prefect_concurrency_task_value: Optional[int] = None,
+    ) -> ResultTuple:
         """
         Run a bunch of tasks, calling one of the above run_tasks functions
         This is the function that should be called most often from self.main()
         It will return a ResultTuple of TaskResults
+
+        Parameters:
+            func: The function to run for each task.
+            input_list: An iterable of function inputs. For each input, a new task will be created with that input passed as the only parameter.
+            name: A name for this task run, for easier reference.
+            retries: Number of times to retry a task before giving up.
+            retry_delay: Delay (in seconds) to wait between task retries.
+            force_sequential: If set to `True`, all tasks in this run will be run in sequence, regardless of backend.
+            force_serial: If set to `True`, all tasks will be run locally (using the internal "serial runner") rather than with this Dataset's usual backend. **Please avoid using this parameter; it will likely be deprecated soon!**
+            max_workers: Maximum number of tasks to run at once, if using a concurrent mode. This value will not override `force_sequential` or `force_serial`. **Warning: this is not yet supported by the Prefect backend. We hope to fix this soon.**
+            prefect_concurrency_tag: If using the Prefect backend, this tag will be used to limit the concurrency of the tasks in this run. **This will eventually be deprecated in favor of `max_workers` once we have implemented that for the Prefect backend.**
+            prefect_concurrency_task_value: If using the Prefect backend, this sets the maximum number of tasks to run at once, similar to `max_workers`. **See the warning above.**
         """
 
         timestamp = datetime.today()
@@ -408,7 +463,6 @@
         elif not isinstance(name, str):
             raise TypeError("Name of task run must be a string")
 
-
        if max_workers is None and hasattr(self, "max_workers"):
            max_workers = self.max_workers
 
@@ -419,10 +473,19 @@
                 name, func, input_list, force_sequential, max_workers=max_workers
             )
         elif self.backend == "prefect":
-            results = self.run_prefect_tasks(name, func, input_list, force_sequential, prefect_concurrency_tag, prefect_concurrency_task_value)
+            results = self.run_prefect_tasks(
+                name,
+                func,
+                input_list,
+                force_sequential,
+                prefect_concurrency_tag,
+                prefect_concurrency_task_value,
+            )
         elif self.backend == "mpi":
-            results = self.run_mpi_tasks(name, func, input_list, force_sequential, max_workers=max_workers)
+            results = self.run_mpi_tasks(
+                name, func, input_list, force_sequential, max_workers=max_workers
+            )
         else:
             raise ValueError(
                 "Requested backend not recognized. Have you called this Dataset's run function?"
             )
@@ -626,8 +689,10 @@
             self.backend = "prefect"
 
             from prefect import flow
-            from prefect.task_runners import SequentialTaskRunner, ConcurrentTaskRunner#, ThreadPoolTaskRunner
-
+            from prefect.task_runners import (  # , ThreadPoolTaskRunner
+                ConcurrentTaskRunner,
+                SequentialTaskRunner,
+            )
 
             if params.task_runner == "sequential":
                 tr = SequentialTaskRunner
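Example usage of the typed `run_tasks()` API above — a minimal sketch, not part of the patch. The import path, `ExampleDataset`, and `download_file` are hypothetical, and it assumes (per the `run_tasks()` docstring) that each `Dict[str, Any]` entry of `input_list` is passed to the task function as its only argument:

```python
from typing import Any, Dict, List

from data_manager import Dataset  # assumed import path


class ExampleDataset(Dataset):
    """Hypothetical Dataset subclass, for illustration only."""

    def download_file(self, task: Dict[str, Any]) -> str:
        # Per the run_tasks() docstring, each input is passed to the task
        # function as its only argument. A real task would write its output
        # through self.tmp_to_dst_file() so an interrupted run never leaves
        # a partially-written file behind.
        return f"downloaded {task['url']} for {task['year']}"

    def main(self):
        # One Dict[str, Any] per task, matching the new
        # input_list: Iterable[Dict[str, Any]] type hints.
        inputs: List[Dict[str, Any]] = [
            {"url": "https://example.com/a.tif", "year": 2020},
            {"url": "https://example.com/b.tif", "year": 2021},
        ]
        results = self.run_tasks(
            self.download_file,
            inputs,
            name="download files",
            retries=2,
            retry_delay=30,
        )
        # run_tasks() returns a ResultTuple of TaskResults;
        # ResultTuple.results() extracts each task's return value.
        for message in results.results():
            print(message)
```

Because `run_tasks()` now declares `-> ResultTuple`, callers get `__len__`, `__getitem__`, and `results()` from the returned sequence without any casting.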