Skip to content

Commit

Permalink
Enable std tqdm bar, refactor for re-use.
Browse files Browse the repository at this point in the history
  • Loading branch information
delucchi-cmu committed Jun 5, 2024
1 parent c5e6095 commit b745a40
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 54 deletions.
5 changes: 1 addition & 4 deletions src/hipscat_import/catalog/resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import healpy as hp
from hipscat.io import FilePointer, file_io
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from tqdm.auto import tqdm

from hipscat_import.catalog.sparse_histogram import SparseHistogram
from hipscat_import.pipeline_resume_plan import PipelineResumePlan
Expand Down Expand Up @@ -40,9 +39,7 @@ def __post_init__(self):

def gather_plan(self):
"""Initialize the plan."""
with tqdm(
total=5, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar
) as step_progress:
with self.print_progress(total=5, stage_name="Planning") as step_progress:
## Make sure it's safe to use existing resume state.
super().safe_to_resume()
step_progress.update(1)
Expand Down
10 changes: 2 additions & 8 deletions src/hipscat_import/catalog/run_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
from hipscat.catalog import PartitionInfo
from hipscat.io import paths
from hipscat.io.parquet_metadata import write_parquet_metadata
from tqdm.auto import tqdm

import hipscat_import.catalog.map_reduce as mr
from hipscat_import.catalog.arguments import ImportArguments
from hipscat_import.pipeline_resume_plan import PipelineResumePlan


def _map_pixels(args, client):
Expand Down Expand Up @@ -114,9 +112,7 @@ def run(args, client):
raise ValueError("args must be type ImportArguments")
_map_pixels(args, client)

with tqdm(
total=2, desc=PipelineResumePlan.get_formatted_stage_name("Binning"), disable=not args.progress_bar
) as step_progress:
with args.resume_plan.print_progress(total=2, stage_name="Binning") as step_progress:
raw_histogram = args.resume_plan.read_histogram(args.mapping_healpix_order)
step_progress.update(1)
if args.constant_healpix_order >= 0:
Expand Down Expand Up @@ -153,9 +149,7 @@ def run(args, client):
_reduce_pixels(args, destination_pixel_map, client)

# All done - write out the metadata
with tqdm(
total=5, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar
) as step_progress:
with args.resume_plan.print_progress(total=5, stage_name="Finishing") as step_progress:
catalog_info = args.to_catalog_info(int(raw_histogram.sum()))
io.write_provenance_info(
catalog_base_dir=args.catalog_path,
Expand Down
10 changes: 6 additions & 4 deletions src/hipscat_import/index/run_index.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Create columnar index of hipscat table using dask for parallelization"""

from hipscat.io import file_io, parquet_metadata, write_metadata
from tqdm.auto import tqdm

import hipscat_import.index.map_reduce as mr
from hipscat_import.index.arguments import IndexArguments
from hipscat_import.pipeline_resume_plan import PipelineResumePlan
from hipscat_import.pipeline_resume_plan import print_progress


def run(args, client):
Expand All @@ -17,8 +16,11 @@ def run(args, client):
rows_written = mr.create_index(args, client)

# All done - write out the metadata
with tqdm(
total=4, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar
with print_progress(
total=4,
stage_name="Finishing",
use_progress_bar=args.progress_bar,
simple_progress_bar=args.simple_progress_bar,
) as step_progress:
index_catalog_info = args.to_catalog_info(int(rows_written))
write_metadata.write_provenance_info(
Expand Down
3 changes: 1 addition & 2 deletions src/hipscat_import/margin_cache/margin_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from hipscat.catalog import PartitionInfo
from hipscat.io import file_io, parquet_metadata, paths, write_metadata
from tqdm.auto import tqdm

import hipscat_import.margin_cache.margin_cache_map_reduce as mcmr
from hipscat_import.margin_cache.margin_cache_resume_plan import MarginCachePlan
Expand Down Expand Up @@ -59,7 +58,7 @@ def generate_margin_cache(args, client):
)
resume_plan.wait_for_reducing(futures)

with tqdm(total=4, desc="Finishing", disable=not args.progress_bar) as step_progress:
with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress:
parquet_metadata.write_parquet_metadata(
args.catalog_path, storage_options=args.output_storage_options
)
Expand Down
6 changes: 2 additions & 4 deletions src/hipscat_import/margin_cache/margin_cache_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from hipscat import pixel_math
from hipscat.io import file_io
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from tqdm.auto import tqdm

from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments
from hipscat_import.pipeline_resume_plan import PipelineResumePlan
Expand All @@ -33,16 +32,15 @@ def __init__(self, args: MarginCacheArguments):
super().__init__(
resume=args.resume,
progress_bar=args.progress_bar,
simple_progress_bar=args.simple_progress_bar,
tmp_path=args.tmp_path,
delete_resume_log_files=args.delete_resume_log_files,
)
self._gather_plan(args)

def _gather_plan(self, args):
"""Initialize the plan."""
with tqdm(
total=3, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar
) as step_progress:
with self.print_progress(total=3, stage_name="Planning") as step_progress:
## Make sure it's safe to use existing resume state.
super().safe_to_resume()
mapping_done = self.is_mapping_done()
Expand Down
89 changes: 72 additions & 17 deletions src/hipscat_import/pipeline_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from dask.distributed import print as dask_print
from hipscat.io import FilePointer, file_io
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from tqdm.auto import tqdm

from tqdm.auto import tqdm as auto_tqdm
from tqdm.std import tqdm as std_tqdm


@dataclass
Expand All @@ -25,6 +27,10 @@ class PipelineResumePlan:
progress_bar: bool = True
"""if true, a tqdm progress bar will be displayed for user
feedback of planning progress"""
simple_progress_bar: bool = False
"""if displaying a progress bar, use a text-only simple progress
bar instead of widget. this can be useful in some environments when running
in a notebook where ipywidgets cannot be used (see `progress_bar` argument)"""
delete_resume_log_files: bool = True
"""should we delete task-level done files once each stage is complete?
if False, we will keep all sub-histograms from the mapping stage, and all
Expand Down Expand Up @@ -131,13 +137,7 @@ def wait_for_futures(self, futures, stage_name, fail_fast=False):
RuntimeError: if any future returns an error status.
"""
some_error = False
formatted_stage_name = self.get_formatted_stage_name(stage_name)
for future in tqdm(
as_completed(futures),
desc=formatted_stage_name,
total=len(futures),
disable=(not self.progress_bar),
):
for future in self.print_progress(as_completed(futures), stage_name=stage_name, total=len(futures)):
if future.status == "error":
some_error = True
if fail_fast:
Expand All @@ -146,18 +146,26 @@ def wait_for_futures(self, futures, stage_name, fail_fast=False):
if some_error:
raise RuntimeError(f"Some {stage_name} stages failed. See logs for details.")

@staticmethod
def get_formatted_stage_name(stage_name) -> str:
"""Create a stage name of consistent minimum length. Ensures that the tqdm
progress bars can line up nicely when multiple stages must run.
def print_progress(self, iterable=None, total=None, stage_name=None):
"""Create a progress bar that will provide user with task feedback.
This is a thin wrapper around the static ``print_progress`` method that uses
member variables for the caller's convenience.
Args:
stage_name (str): name of the stage (e.g. mapping, reducing)
iterable (iterable): Optional. provides iterations to progress updates.
total (int): Optional. Expected iterations.
stage_name (str): name of the stage (e.g. mapping, reducing). this will
be further formatted with ``get_formatted_stage_name``, so the caller
doesn't need to worry about that.
"""
if stage_name is None or len(stage_name) == 0:
stage_name = "progress"

return f"{stage_name.capitalize(): <10}"
return print_progress(
iterable=iterable,
total=total,
stage_name=stage_name,
use_progress_bar=self.progress_bar,
simple_progress_bar=self.simple_progress_bar,
)

def check_original_input_paths(self, input_paths):
"""Validate that we're operating on the same file set as the original pipeline,
Expand Down Expand Up @@ -230,3 +238,50 @@ def print_task_failure(custom_message, exception):
except Exception: # pylint: disable=broad-exception-caught
pass
dask_print(exception)


def get_formatted_stage_name(stage_name) -> str:
"""Create a stage name of consistent minimum length. Ensures that the tqdm
progress bars can line up nicely when multiple stages must run.
Args:
stage_name (str): name of the stage (e.g. mapping, reducing)
"""
if stage_name is None or len(stage_name) == 0:
stage_name = "progress"

return f"{stage_name.capitalize(): <10}"


def print_progress(
iterable=None, total=None, stage_name=None, use_progress_bar=True, simple_progress_bar=False
):
"""Create a progress bar that will provide user with task feedback.
Args:
iterable (iterable): Optional. provides iterations to progress updates.
total (int): Optional. Expected iterations.
stage_name (str): name of the stage (e.g. mapping, reducing). this will
be further formatted with ``get_formatted_stage_name``, so the caller
doesn't need to worry about that.
use_progress_bar (bool): should we display any progress. typically False
when no stdout is expected.
simple_progress_bar (bool): if displaying a progress bar, use a text-only
simple progress bar instead of widget. this can be useful when running
in a particular notebook where ipywidgets cannot be used
(only used when ``use_progress_bar`` is True)
"""
if simple_progress_bar:
return std_tqdm(

Check warning on line 275 in src/hipscat_import/pipeline_resume_plan.py

View check run for this annotation

Codecov / codecov/patch

src/hipscat_import/pipeline_resume_plan.py#L275

Added line #L275 was not covered by tests
iterable,
desc=get_formatted_stage_name(stage_name),
total=total,
disable=not use_progress_bar,
)

return auto_tqdm(
iterable,
desc=get_formatted_stage_name(stage_name),
total=total,
disable=not use_progress_bar,
)
6 changes: 5 additions & 1 deletion src/hipscat_import/runtime_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ class RuntimeArguments:
the pipeline where we left off. If False, we start the import from scratch,
overwriting any content of the output directory."""
progress_bar: bool = True
"""if true, a tqdm progress bar will be displayed for user
"""if true, a progress bar will be displayed for user
feedback of map reduce progress"""
simple_progress_bar: bool = False
"""if displaying a progress bar, use a text-only simple progress
bar instead of widget. this can be useful in some environments when running
in a notebook where ipywidgets cannot be used (see `progress_bar` argument)"""
dask_tmp: str = ""
"""directory for dask worker space. this should be local to
the execution of the pipeline, for speed of reads and writes"""
Expand Down
6 changes: 2 additions & 4 deletions src/hipscat_import/soap/resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from hipscat.io import file_io
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from hipscat.pixel_tree import PixelAlignment, align_trees
from tqdm.auto import tqdm

from hipscat_import.pipeline_resume_plan import PipelineResumePlan
from hipscat_import.soap.arguments import SoapArguments
Expand Down Expand Up @@ -39,16 +38,15 @@ def __init__(self, args: SoapArguments):
super().__init__(
resume=args.resume,
progress_bar=args.progress_bar,
simple_progress_bar=args.simple_progress_bar,
tmp_path=args.tmp_path,
delete_resume_log_files=args.delete_resume_log_files,
)
self.gather_plan(args)

def gather_plan(self, args):
"""Initialize the plan."""
with tqdm(
total=3, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar
) as step_progress:
with self.print_progress(total=3, stage_name="Planning") as step_progress:
## Make sure it's safe to use existing resume state.
super().safe_to_resume()
step_progress.update(1)
Expand Down
6 changes: 1 addition & 5 deletions src/hipscat_import/soap/run_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from hipscat.catalog.association_catalog.partition_join_info import PartitionJoinInfo
from hipscat.io import parquet_metadata, paths, write_metadata
from tqdm.auto import tqdm

from hipscat_import.pipeline_resume_plan import PipelineResumePlan
from hipscat_import.soap.arguments import SoapArguments
from hipscat_import.soap.map_reduce import combine_partial_results, count_joins, reduce_joins
from hipscat_import.soap.resume_plan import SoapPlan
Expand Down Expand Up @@ -50,9 +48,7 @@ def run(args, client):
resume_plan.wait_for_reducing(futures)

# All done - write out the metadata
with tqdm(
total=4, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar
) as step_progress:
with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress:
if args.write_leaf_files:
parquet_metadata.write_parquet_metadata(
args.catalog_path,
Expand Down
10 changes: 5 additions & 5 deletions tests/hipscat_import/test_pipeline_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy.testing as npt
import pytest

from hipscat_import.pipeline_resume_plan import PipelineResumePlan
from hipscat_import.pipeline_resume_plan import PipelineResumePlan, get_formatted_stage_name


def test_done_key(tmp_path):
Expand Down Expand Up @@ -138,16 +138,16 @@ def error_on_even(argument):

def test_formatted_stage_name():
"""Test that we make pretty stage names for presenting in progress bars"""
formatted = PipelineResumePlan.get_formatted_stage_name(None)
formatted = get_formatted_stage_name(None)
assert formatted == "Progress "

formatted = PipelineResumePlan.get_formatted_stage_name("")
formatted = get_formatted_stage_name("")
assert formatted == "Progress "

formatted = PipelineResumePlan.get_formatted_stage_name("stage")
formatted = get_formatted_stage_name("stage")
assert formatted == "Stage "

formatted = PipelineResumePlan.get_formatted_stage_name("very long stage name")
formatted = get_formatted_stage_name("very long stage name")
assert formatted == "Very long stage name"


Expand Down

0 comments on commit b745a40

Please sign in to comment.