From b745a40b993167ee6ab80ba1a9905a22c1388cf9 Mon Sep 17 00:00:00 2001 From: Melissa DeLucchi Date: Wed, 5 Jun 2024 14:39:34 -0400 Subject: [PATCH] Enable std tqdm bar, refactor for re-use. --- src/hipscat_import/catalog/resume_plan.py | 5 +- src/hipscat_import/catalog/run_import.py | 10 +-- src/hipscat_import/index/run_index.py | 10 ++- .../margin_cache/margin_cache.py | 3 +- .../margin_cache/margin_cache_resume_plan.py | 6 +- src/hipscat_import/pipeline_resume_plan.py | 89 +++++++++++++++---- src/hipscat_import/runtime_arguments.py | 6 +- src/hipscat_import/soap/resume_plan.py | 6 +- src/hipscat_import/soap/run_soap.py | 6 +- .../test_pipeline_resume_plan.py | 10 +-- 10 files changed, 97 insertions(+), 54 deletions(-) diff --git a/src/hipscat_import/catalog/resume_plan.py b/src/hipscat_import/catalog/resume_plan.py index 604d8a91..3271d161 100644 --- a/src/hipscat_import/catalog/resume_plan.py +++ b/src/hipscat_import/catalog/resume_plan.py @@ -8,7 +8,6 @@ import healpy as hp from hipscat.io import FilePointer, file_io from hipscat.pixel_math.healpix_pixel import HealpixPixel -from tqdm.auto import tqdm from hipscat_import.catalog.sparse_histogram import SparseHistogram from hipscat_import.pipeline_resume_plan import PipelineResumePlan @@ -40,9 +39,7 @@ def __post_init__(self): def gather_plan(self): """Initialize the plan.""" - with tqdm( - total=5, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar - ) as step_progress: + with self.print_progress(total=5, stage_name="Planning") as step_progress: ## Make sure it's safe to use existing resume state. super().safe_to_resume() step_progress.update(1) diff --git a/src/hipscat_import/catalog/run_import.py b/src/hipscat_import/catalog/run_import.py index 1b97b547..fad8882d 100644 --- a/src/hipscat_import/catalog/run_import.py +++ b/src/hipscat_import/catalog/run_import.py @@ -10,11 +10,9 @@ from hipscat.catalog import PartitionInfo from hipscat.io import paths from hipscat.io.parquet_metadata import write_parquet_metadata -from tqdm.auto import tqdm import hipscat_import.catalog.map_reduce as mr from hipscat_import.catalog.arguments import ImportArguments -from hipscat_import.pipeline_resume_plan import PipelineResumePlan def _map_pixels(args, client): @@ -114,9 +112,7 @@ def run(args, client): raise ValueError("args must be type ImportArguments") _map_pixels(args, client) - with tqdm( - total=2, desc=PipelineResumePlan.get_formatted_stage_name("Binning"), disable=not args.progress_bar - ) as step_progress: + with args.resume_plan.print_progress(total=2, stage_name="Binning") as step_progress: raw_histogram = args.resume_plan.read_histogram(args.mapping_healpix_order) step_progress.update(1) if args.constant_healpix_order >= 0: @@ -153,9 +149,7 @@ def run(args, client): _reduce_pixels(args, destination_pixel_map, client) # All done - write out the metadata - with tqdm( - total=5, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar - ) as step_progress: + with args.resume_plan.print_progress(total=5, stage_name="Finishing") as step_progress: catalog_info = args.to_catalog_info(int(raw_histogram.sum())) io.write_provenance_info( catalog_base_dir=args.catalog_path, diff --git a/src/hipscat_import/index/run_index.py b/src/hipscat_import/index/run_index.py index c4279623..bbf870d1 100644 --- a/src/hipscat_import/index/run_index.py +++ b/src/hipscat_import/index/run_index.py @@ -1,11 +1,10 @@ """Create columnar index of hipscat table using dask for parallelization""" from hipscat.io import file_io, parquet_metadata, write_metadata -from tqdm.auto import tqdm import hipscat_import.index.map_reduce as mr from hipscat_import.index.arguments import IndexArguments -from hipscat_import.pipeline_resume_plan import PipelineResumePlan +from hipscat_import.pipeline_resume_plan import print_progress def run(args, client): @@ -17,8 +16,11 @@ def run(args, client): rows_written = mr.create_index(args, client) # All done - write out the metadata - with tqdm( - total=4, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar + with print_progress( + total=4, + stage_name="Finishing", + use_progress_bar=args.progress_bar, + simple_progress_bar=args.simple_progress_bar, ) as step_progress: index_catalog_info = args.to_catalog_info(int(rows_written)) write_metadata.write_provenance_info( diff --git a/src/hipscat_import/margin_cache/margin_cache.py b/src/hipscat_import/margin_cache/margin_cache.py index 217fcc56..8611393f 100644 --- a/src/hipscat_import/margin_cache/margin_cache.py +++ b/src/hipscat_import/margin_cache/margin_cache.py @@ -1,6 +1,5 @@ from hipscat.catalog import PartitionInfo from hipscat.io import file_io, parquet_metadata, paths, write_metadata -from tqdm.auto import tqdm import hipscat_import.margin_cache.margin_cache_map_reduce as mcmr from hipscat_import.margin_cache.margin_cache_resume_plan import MarginCachePlan @@ -59,7 +58,7 @@ def generate_margin_cache(args, client): ) resume_plan.wait_for_reducing(futures) - with tqdm(total=4, desc="Finishing", disable=not args.progress_bar) as step_progress: + with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress: parquet_metadata.write_parquet_metadata( args.catalog_path, storage_options=args.output_storage_options ) diff --git a/src/hipscat_import/margin_cache/margin_cache_resume_plan.py b/src/hipscat_import/margin_cache/margin_cache_resume_plan.py index 72aec80d..9a6274df 100644 --- a/src/hipscat_import/margin_cache/margin_cache_resume_plan.py +++ b/src/hipscat_import/margin_cache/margin_cache_resume_plan.py @@ -9,7 +9,6 @@ from hipscat import pixel_math from hipscat.io import file_io from hipscat.pixel_math.healpix_pixel import HealpixPixel -from tqdm.auto import tqdm from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments from hipscat_import.pipeline_resume_plan import PipelineResumePlan @@ -33,6 +32,7 @@ def __init__(self, args: MarginCacheArguments): super().__init__( resume=args.resume, progress_bar=args.progress_bar, + simple_progress_bar=args.simple_progress_bar, tmp_path=args.tmp_path, delete_resume_log_files=args.delete_resume_log_files, ) @@ -40,9 +40,7 @@ def __init__(self, args: MarginCacheArguments): def _gather_plan(self, args): """Initialize the plan.""" - with tqdm( - total=3, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar - ) as step_progress: + with self.print_progress(total=3, stage_name="Planning") as step_progress: ## Make sure it's safe to use existing resume state. super().safe_to_resume() mapping_done = self.is_mapping_done() diff --git a/src/hipscat_import/pipeline_resume_plan.py b/src/hipscat_import/pipeline_resume_plan.py index 7e859e91..95cf7688 100644 --- a/src/hipscat_import/pipeline_resume_plan.py +++ b/src/hipscat_import/pipeline_resume_plan.py @@ -10,7 +10,9 @@ from dask.distributed import print as dask_print from hipscat.io import FilePointer, file_io from hipscat.pixel_math.healpix_pixel import HealpixPixel -from tqdm.auto import tqdm + +from tqdm.auto import tqdm as auto_tqdm +from tqdm.std import tqdm as std_tqdm @dataclass @@ -25,6 +27,10 @@ class PipelineResumePlan: progress_bar: bool = True """if true, a tqdm progress bar will be displayed for user feedback of planning progress""" + simple_progress_bar: bool = False + """if displaying a progress bar, use a text-only simple progress + bar instead of widget. this can be useful in some environments when running + in a notebook where ipywidgets cannot be used (see `progress_bar` argument)""" delete_resume_log_files: bool = True """should we delete task-level done files once each stage is complete? if False, we will keep all sub-histograms from the mapping stage, and all @@ -131,13 +137,7 @@ def wait_for_futures(self, futures, stage_name, fail_fast=False): RuntimeError: if any future returns an error status. """ some_error = False - formatted_stage_name = self.get_formatted_stage_name(stage_name) - for future in tqdm( - as_completed(futures), - desc=formatted_stage_name, - total=len(futures), - disable=(not self.progress_bar), - ): + for future in self.print_progress(as_completed(futures), stage_name=stage_name, total=len(futures)): if future.status == "error": some_error = True if fail_fast: @@ -146,18 +146,26 @@ def wait_for_futures(self, futures, stage_name, fail_fast=False): if some_error: raise RuntimeError(f"Some {stage_name} stages failed. See logs for details.") - @staticmethod - def get_formatted_stage_name(stage_name) -> str: - """Create a stage name of consistent minimum length. Ensures that the tqdm - progress bars can line up nicely when multiple stages must run. + def print_progress(self, iterable=None, total=None, stage_name=None): + """Create a progress bar that will provide user with task feedback. + + This is a thin wrapper around the static ``print_progress`` method that uses + member variables for the caller's convenience. Args: - stage_name (str): name of the stage (e.g. mapping, reducing) + iterable (iterable): Optional. provides iterations to progress updates. + total (int): Optional. Expected iterations. + stage_name (str): name of the stage (e.g. mapping, reducing). this will + be further formatted with ``get_formatted_stage_name``, so the caller + doesn't need to worry about that. """ - if stage_name is None or len(stage_name) == 0: - stage_name = "progress" - - return f"{stage_name.capitalize(): <10}" + return print_progress( + iterable=iterable, + total=total, + stage_name=stage_name, + use_progress_bar=self.progress_bar, + simple_progress_bar=self.simple_progress_bar, + ) def check_original_input_paths(self, input_paths): """Validate that we're operating on the same file set as the original pipeline, @@ -230,3 +238,50 @@ def print_task_failure(custom_message, exception): except Exception: # pylint: disable=broad-exception-caught pass dask_print(exception) + + +def get_formatted_stage_name(stage_name) -> str: + """Create a stage name of consistent minimum length. Ensures that the tqdm + progress bars can line up nicely when multiple stages must run. + + Args: + stage_name (str): name of the stage (e.g. mapping, reducing) + """ + if stage_name is None or len(stage_name) == 0: + stage_name = "progress" + + return f"{stage_name.capitalize(): <10}" + + +def print_progress( + iterable=None, total=None, stage_name=None, use_progress_bar=True, simple_progress_bar=False +): + """Create a progress bar that will provide user with task feedback. + + Args: + iterable (iterable): Optional. provides iterations to progress updates. + total (int): Optional. Expected iterations. + stage_name (str): name of the stage (e.g. mapping, reducing). this will + be further formatted with ``get_formatted_stage_name``, so the caller + doesn't need to worry about that. + use_progress_bar (bool): should we display any progress. typically False + when no stdout is expected. + simple_progress_bar (bool): if displaying a progress bar, use a text-only + simple progress bar instead of widget. this can be useful when running + in a particular notebook where ipywidgets cannot be used + (only used when ``use_progress_bar`` is True) + """ + if simple_progress_bar: + return std_tqdm( + iterable, + desc=get_formatted_stage_name(stage_name), + total=total, + disable=not use_progress_bar, + ) + + return auto_tqdm( + iterable, + desc=get_formatted_stage_name(stage_name), + total=total, + disable=not use_progress_bar, + ) diff --git a/src/hipscat_import/runtime_arguments.py b/src/hipscat_import/runtime_arguments.py index 4f4d7126..9ce5e8ad 100644 --- a/src/hipscat_import/runtime_arguments.py +++ b/src/hipscat_import/runtime_arguments.py @@ -32,8 +32,12 @@ class RuntimeArguments: the pipeline where we left off. If False, we start the import from scratch, overwriting any content of the output directory.""" progress_bar: bool = True - """if true, a tqdm progress bar will be displayed for user + """if true, a progress bar will be displayed for user feedback of map reduce progress""" + simple_progress_bar: bool = False + """if displaying a progress bar, use a text-only simple progress + bar instead of widget. this can be useful in some environments when running + in a notebook where ipywidgets cannot be used (see `progress_bar` argument)""" dask_tmp: str = "" """directory for dask worker space. this should be local to the execution of the pipeline, for speed of reads and writes""" diff --git a/src/hipscat_import/soap/resume_plan.py b/src/hipscat_import/soap/resume_plan.py index be1a6f86..3afd8626 100644 --- a/src/hipscat_import/soap/resume_plan.py +++ b/src/hipscat_import/soap/resume_plan.py @@ -11,7 +11,6 @@ from hipscat.io import file_io from hipscat.pixel_math.healpix_pixel import HealpixPixel from hipscat.pixel_tree import PixelAlignment, align_trees -from tqdm.auto import tqdm from hipscat_import.pipeline_resume_plan import PipelineResumePlan from hipscat_import.soap.arguments import SoapArguments @@ -39,6 +38,7 @@ def __init__(self, args: SoapArguments): super().__init__( resume=args.resume, progress_bar=args.progress_bar, + simple_progress_bar=args.simple_progress_bar, tmp_path=args.tmp_path, delete_resume_log_files=args.delete_resume_log_files, ) @@ -46,9 +46,7 @@ def __init__(self, args: SoapArguments): def gather_plan(self, args): """Initialize the plan.""" - with tqdm( - total=3, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar - ) as step_progress: + with self.print_progress(total=3, stage_name="Planning") as step_progress: ## Make sure it's safe to use existing resume state. super().safe_to_resume() step_progress.update(1) diff --git a/src/hipscat_import/soap/run_soap.py b/src/hipscat_import/soap/run_soap.py index dafafae4..34d50b8a 100644 --- a/src/hipscat_import/soap/run_soap.py +++ b/src/hipscat_import/soap/run_soap.py @@ -5,9 +5,7 @@ from hipscat.catalog.association_catalog.partition_join_info import PartitionJoinInfo from hipscat.io import parquet_metadata, paths, write_metadata -from tqdm.auto import tqdm -from hipscat_import.pipeline_resume_plan import PipelineResumePlan from hipscat_import.soap.arguments import SoapArguments from hipscat_import.soap.map_reduce import combine_partial_results, count_joins, reduce_joins from hipscat_import.soap.resume_plan import SoapPlan @@ -50,9 +48,7 @@ def run(args, client): resume_plan.wait_for_reducing(futures) # All done - write out the metadata - with tqdm( - total=4, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar - ) as step_progress: + with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress: if args.write_leaf_files: parquet_metadata.write_parquet_metadata( args.catalog_path, diff --git a/tests/hipscat_import/test_pipeline_resume_plan.py b/tests/hipscat_import/test_pipeline_resume_plan.py index b694f33c..bef75346 100644 --- a/tests/hipscat_import/test_pipeline_resume_plan.py +++ b/tests/hipscat_import/test_pipeline_resume_plan.py @@ -6,7 +6,7 @@ import numpy.testing as npt import pytest -from hipscat_import.pipeline_resume_plan import PipelineResumePlan +from hipscat_import.pipeline_resume_plan import PipelineResumePlan, get_formatted_stage_name def test_done_key(tmp_path): @@ -138,16 +138,16 @@ def error_on_even(argument): def test_formatted_stage_name(): """Test that we make pretty stage names for presenting in progress bars""" - formatted = PipelineResumePlan.get_formatted_stage_name(None) + formatted = get_formatted_stage_name(None) assert formatted == "Progress " - formatted = PipelineResumePlan.get_formatted_stage_name("") + formatted = get_formatted_stage_name("") assert formatted == "Progress " - formatted = PipelineResumePlan.get_formatted_stage_name("stage") + formatted = get_formatted_stage_name("stage") assert formatted == "Stage " - formatted = PipelineResumePlan.get_formatted_stage_name("very long stage name") + formatted = get_formatted_stage_name("very long stage name") assert formatted == "Very long stage name"