Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable std tqdm bar, refactor for re-use. #330

Merged
merged 4 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/hipscat_import/catalog/resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import healpy as hp
from hipscat.io import FilePointer, file_io
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from tqdm.auto import tqdm

from hipscat_import.catalog.sparse_histogram import SparseHistogram
from hipscat_import.pipeline_resume_plan import PipelineResumePlan
Expand Down Expand Up @@ -40,9 +39,7 @@ def __post_init__(self):

def gather_plan(self):
"""Initialize the plan."""
with tqdm(
total=5, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar
) as step_progress:
with self.print_progress(total=5, stage_name="Planning") as step_progress:
## Make sure it's safe to use existing resume state.
super().safe_to_resume()
step_progress.update(1)
Expand Down
10 changes: 2 additions & 8 deletions src/hipscat_import/catalog/run_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
from hipscat.catalog import PartitionInfo
from hipscat.io import paths
from hipscat.io.parquet_metadata import write_parquet_metadata
from tqdm.auto import tqdm

import hipscat_import.catalog.map_reduce as mr
from hipscat_import.catalog.arguments import ImportArguments
from hipscat_import.pipeline_resume_plan import PipelineResumePlan


def _map_pixels(args, client):
Expand Down Expand Up @@ -114,9 +112,7 @@ def run(args, client):
raise ValueError("args must be type ImportArguments")
_map_pixels(args, client)

with tqdm(
total=2, desc=PipelineResumePlan.get_formatted_stage_name("Binning"), disable=not args.progress_bar
) as step_progress:
with args.resume_plan.print_progress(total=2, stage_name="Binning") as step_progress:
raw_histogram = args.resume_plan.read_histogram(args.mapping_healpix_order)
step_progress.update(1)
if args.constant_healpix_order >= 0:
Expand Down Expand Up @@ -153,9 +149,7 @@ def run(args, client):
_reduce_pixels(args, destination_pixel_map, client)

# All done - write out the metadata
with tqdm(
total=5, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar
) as step_progress:
with args.resume_plan.print_progress(total=5, stage_name="Finishing") as step_progress:
catalog_info = args.to_catalog_info(int(raw_histogram.sum()))
io.write_provenance_info(
catalog_base_dir=args.catalog_path,
Expand Down
10 changes: 6 additions & 4 deletions src/hipscat_import/index/run_index.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Create columnar index of hipscat table using dask for parallelization"""

from hipscat.io import file_io, parquet_metadata, write_metadata
from tqdm.auto import tqdm

import hipscat_import.index.map_reduce as mr
from hipscat_import.index.arguments import IndexArguments
from hipscat_import.pipeline_resume_plan import PipelineResumePlan
from hipscat_import.pipeline_resume_plan import print_progress


def run(args, client):
Expand All @@ -17,8 +16,11 @@ def run(args, client):
rows_written = mr.create_index(args, client)

# All done - write out the metadata
with tqdm(
total=4, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar
with print_progress(
total=4,
stage_name="Finishing",
use_progress_bar=args.progress_bar,
simple_progress_bar=args.simple_progress_bar,
) as step_progress:
index_catalog_info = args.to_catalog_info(int(rows_written))
write_metadata.write_provenance_info(
Expand Down
3 changes: 1 addition & 2 deletions src/hipscat_import/margin_cache/margin_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from hipscat.catalog import PartitionInfo
from hipscat.io import file_io, parquet_metadata, paths, write_metadata
from tqdm.auto import tqdm

import hipscat_import.margin_cache.margin_cache_map_reduce as mcmr
from hipscat_import.margin_cache.margin_cache_resume_plan import MarginCachePlan
Expand Down Expand Up @@ -59,7 +58,7 @@ def generate_margin_cache(args, client):
)
resume_plan.wait_for_reducing(futures)

with tqdm(total=4, desc="Finishing", disable=not args.progress_bar) as step_progress:
with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress:
parquet_metadata.write_parquet_metadata(
args.catalog_path, storage_options=args.output_storage_options
)
Expand Down
6 changes: 2 additions & 4 deletions src/hipscat_import/margin_cache/margin_cache_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from hipscat import pixel_math
from hipscat.io import file_io
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from tqdm.auto import tqdm

from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments
from hipscat_import.pipeline_resume_plan import PipelineResumePlan
Expand All @@ -33,16 +32,15 @@ def __init__(self, args: MarginCacheArguments):
super().__init__(
resume=args.resume,
progress_bar=args.progress_bar,
simple_progress_bar=args.simple_progress_bar,
tmp_path=args.tmp_path,
delete_resume_log_files=args.delete_resume_log_files,
)
self._gather_plan(args)

def _gather_plan(self, args):
"""Initialize the plan."""
with tqdm(
total=3, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar
) as step_progress:
with self.print_progress(total=3, stage_name="Planning") as step_progress:
## Make sure it's safe to use existing resume state.
super().safe_to_resume()
mapping_done = self.is_mapping_done()
Expand Down
88 changes: 71 additions & 17 deletions src/hipscat_import/pipeline_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from dask.distributed import print as dask_print
from hipscat.io import FilePointer, file_io
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from tqdm.auto import tqdm
from tqdm.auto import tqdm as auto_tqdm
from tqdm.std import tqdm as std_tqdm


@dataclass
Expand All @@ -25,6 +26,10 @@ class PipelineResumePlan:
progress_bar: bool = True
"""if true, a tqdm progress bar will be displayed for user
feedback of planning progress"""
simple_progress_bar: bool = False
"""if displaying a progress bar, use a text-only simple progress
bar instead of widget. this can be useful in some environments when running
in a notebook where ipywidgets cannot be used (see `progress_bar` argument)"""
delete_resume_log_files: bool = True
"""should we delete task-level done files once each stage is complete?
if False, we will keep all sub-histograms from the mapping stage, and all
Expand Down Expand Up @@ -131,13 +136,7 @@ def wait_for_futures(self, futures, stage_name, fail_fast=False):
RuntimeError: if any future returns an error status.
"""
some_error = False
formatted_stage_name = self.get_formatted_stage_name(stage_name)
for future in tqdm(
as_completed(futures),
desc=formatted_stage_name,
total=len(futures),
disable=(not self.progress_bar),
):
for future in self.print_progress(as_completed(futures), stage_name=stage_name, total=len(futures)):
if future.status == "error":
some_error = True
if fail_fast:
Expand All @@ -146,18 +145,26 @@ def wait_for_futures(self, futures, stage_name, fail_fast=False):
if some_error:
raise RuntimeError(f"Some {stage_name} stages failed. See logs for details.")

@staticmethod
def get_formatted_stage_name(stage_name) -> str:
"""Create a stage name of consistent minimum length. Ensures that the tqdm
progress bars can line up nicely when multiple stages must run.
def print_progress(self, iterable=None, total=None, stage_name=None):
"""Create a progress bar that will provide user with task feedback.

This is a thin wrapper around the module-level ``print_progress`` function that uses
member variables for the caller's convenience.

Args:
stage_name (str): name of the stage (e.g. mapping, reducing)
iterable (iterable): Optional. provides iterations to progress updates.
total (int): Optional. Expected iterations.
stage_name (str): name of the stage (e.g. mapping, reducing). this will
be further formatted with ``get_formatted_stage_name``, so the caller
doesn't need to worry about that.
"""
if stage_name is None or len(stage_name) == 0:
stage_name = "progress"

return f"{stage_name.capitalize(): <10}"
return print_progress(
iterable=iterable,
total=total,
stage_name=stage_name,
use_progress_bar=self.progress_bar,
simple_progress_bar=self.simple_progress_bar,
)

def check_original_input_paths(self, input_paths):
"""Validate that we're operating on the same file set as the original pipeline,
Expand Down Expand Up @@ -230,3 +237,50 @@ def print_task_failure(custom_message, exception):
except Exception: # pylint: disable=broad-exception-caught
pass
dask_print(exception)


def get_formatted_stage_name(stage_name) -> str:
    """Create a stage name of consistent minimum length.

    Ensures that the tqdm progress bars can line up nicely when multiple
    stages must run.

    Args:
        stage_name (str): name of the stage (e.g. mapping, reducing). ``None``
            or an empty string falls back to ``"progress"``.

    Returns:
        str: the capitalized stage name, left-justified and space-padded to a
        minimum width of 10 characters. Longer names are returned unshortened.
    """
    # A single truthiness check covers both None and the empty string.
    if not stage_name:
        stage_name = "progress"

    return f"{stage_name.capitalize(): <10}"


def print_progress(
    iterable=None, total=None, stage_name=None, use_progress_bar=True, simple_progress_bar=False
):
    """Build a tqdm progress bar that gives the user feedback on task progress.

    Args:
        iterable (iterable): Optional. items to iterate over; each iteration
            advances the bar.
        total (int): Optional. Expected number of iterations.
        stage_name (str): name of the stage (e.g. mapping, reducing). It is
            passed through ``get_formatted_stage_name`` before being used as
            the bar description, so the caller doesn't need to format it.
        use_progress_bar (bool): should we display any progress. typically False
            when no stdout is expected. When False the bar is created disabled.
        simple_progress_bar (bool): if displaying a progress bar, use a text-only
            simple progress bar instead of widget. this can be useful when running
            in a particular notebook where ipywidgets cannot be used
            (only used when ``use_progress_bar`` is True)

    Returns:
        The constructed tqdm object (``std_tqdm`` when ``simple_progress_bar``
        is set, ``auto_tqdm`` otherwise).
    """
    # Both flavours take the same arguments, so pick the class once and
    # construct it in a single place.
    progress_class = std_tqdm if simple_progress_bar else auto_tqdm
    description = get_formatted_stage_name(stage_name)
    return progress_class(
        iterable,
        desc=description,
        total=total,
        disable=not use_progress_bar,
    )
6 changes: 5 additions & 1 deletion src/hipscat_import/runtime_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ class RuntimeArguments:
the pipeline where we left off. If False, we start the import from scratch,
overwriting any content of the output directory."""
progress_bar: bool = True
"""if true, a tqdm progress bar will be displayed for user
"""if true, a progress bar will be displayed for user
feedback of map reduce progress"""
simple_progress_bar: bool = False
"""if displaying a progress bar, use a text-only simple progress
bar instead of widget. this can be useful in some environments when running
in a notebook where ipywidgets cannot be used (see `progress_bar` argument)"""
dask_tmp: str = ""
"""directory for dask worker space. this should be local to
the execution of the pipeline, for speed of reads and writes"""
Expand Down
6 changes: 2 additions & 4 deletions src/hipscat_import/soap/resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from hipscat.io import file_io
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from hipscat.pixel_tree import PixelAlignment, align_trees
from tqdm.auto import tqdm

from hipscat_import.pipeline_resume_plan import PipelineResumePlan
from hipscat_import.soap.arguments import SoapArguments
Expand Down Expand Up @@ -39,16 +38,15 @@ def __init__(self, args: SoapArguments):
super().__init__(
resume=args.resume,
progress_bar=args.progress_bar,
simple_progress_bar=args.simple_progress_bar,
tmp_path=args.tmp_path,
delete_resume_log_files=args.delete_resume_log_files,
)
self.gather_plan(args)

def gather_plan(self, args):
"""Initialize the plan."""
with tqdm(
total=3, desc=self.get_formatted_stage_name("Planning"), disable=not self.progress_bar
) as step_progress:
with self.print_progress(total=3, stage_name="Planning") as step_progress:
## Make sure it's safe to use existing resume state.
super().safe_to_resume()
step_progress.update(1)
Expand Down
6 changes: 1 addition & 5 deletions src/hipscat_import/soap/run_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from hipscat.catalog.association_catalog.partition_join_info import PartitionJoinInfo
from hipscat.io import parquet_metadata, paths, write_metadata
from tqdm.auto import tqdm

from hipscat_import.pipeline_resume_plan import PipelineResumePlan
from hipscat_import.soap.arguments import SoapArguments
from hipscat_import.soap.map_reduce import combine_partial_results, count_joins, reduce_joins
from hipscat_import.soap.resume_plan import SoapPlan
Expand Down Expand Up @@ -50,9 +48,7 @@ def run(args, client):
resume_plan.wait_for_reducing(futures)

# All done - write out the metadata
with tqdm(
total=4, desc=PipelineResumePlan.get_formatted_stage_name("Finishing"), disable=not args.progress_bar
) as step_progress:
with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress:
if args.write_leaf_files:
parquet_metadata.write_parquet_metadata(
args.catalog_path,
Expand Down
32 changes: 27 additions & 5 deletions tests/hipscat_import/test_pipeline_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy.testing as npt
import pytest

from hipscat_import.pipeline_resume_plan import PipelineResumePlan
from hipscat_import.pipeline_resume_plan import PipelineResumePlan, get_formatted_stage_name


def test_done_key(tmp_path):
Expand Down Expand Up @@ -119,6 +119,28 @@ def error_on_even(argument):
plan.wait_for_futures(futures, "test")


@pytest.mark.dask
def test_wait_for_futures_progress(tmp_path, dask_client, capsys):
    """Test that waiting on futures reports progress with the simple progress bar.

    Checks that the formatted stage name and a completed ("100%") bar show up
    in the captured stderr output."""
    plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=True, simple_progress_bar=True, resume=False)

    def error_on_even(argument):
        """Silly little method used to test futures that fail under predictable conditions"""
        if argument % 2 == 0:
            raise RuntimeError("we are at odds with evens")

    ## Everything is fine if we're all odd, but use a silly name so it's
    ## clear that the stage name is present, and well-formatted.
    futures = [dask_client.submit(error_on_even, 1)]
    plan.wait_for_futures(futures, "teeeest")

    ## The stage name is expected capitalized ("teeeest" -> "Teeeest"), and
    ## the single submitted future should bring the bar to completion.
    captured = capsys.readouterr()
    assert "Teeeest" in captured.err
    assert "100%" in captured.err


@pytest.mark.dask
def test_wait_for_futures_fail_fast(tmp_path, dask_client):
"""Test that we can wait around for futures to complete.
Expand All @@ -138,16 +160,16 @@ def error_on_even(argument):

def test_formatted_stage_name():
"""Test that we make pretty stage names for presenting in progress bars"""
formatted = PipelineResumePlan.get_formatted_stage_name(None)
formatted = get_formatted_stage_name(None)
assert formatted == "Progress "

formatted = PipelineResumePlan.get_formatted_stage_name("")
formatted = get_formatted_stage_name("")
assert formatted == "Progress "

formatted = PipelineResumePlan.get_formatted_stage_name("stage")
formatted = get_formatted_stage_name("stage")
assert formatted == "Stage "

formatted = PipelineResumePlan.get_formatted_stage_name("very long stage name")
formatted = get_formatted_stage_name("very long stage name")
assert formatted == "Very long stage name"


Expand Down