Skip to content

Commit

Permalink
Merge pull request #88 from astronomy-commons/delucchi/pipeline
Browse files Browse the repository at this point in the history
Consolidate pipeline logic
  • Loading branch information
delucchi-cmu committed Jun 21, 2023
2 parents 24b0cb4 + 331698a commit f0cd4bf
Show file tree
Hide file tree
Showing 21 changed files with 216 additions and 172 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/testing-and-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ jobs:
run: |
sudo apt-get update
python -m pip install --upgrade pip
pip install .
pip install .[dev]
pip install -e .[dev]
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Run unit tests with pytest
run: |
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/intro_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@
" # Calculate the cosine of each value of X\n",
" z = np.cos(x)\n",
" # Plot the sine wave in blue, using degrees rather than radians on the X axis\n",
" pl.plot(xdeg, y, color='blue', label='Sine wave')\n",
" pl.plot(xdeg, y, color=\"blue\", label=\"Sine wave\")\n",
" # Plot the cos wave in green, using degrees rather than radians on the X axis\n",
" pl.plot(xdeg, z, color='green', label='Cosine wave')\n",
" pl.plot(xdeg, z, color=\"green\", label=\"Cosine wave\")\n",
" pl.xlabel(\"Degrees\")\n",
" # More sensible X axis values\n",
" pl.xticks(np.arange(0, 361, 45))\n",
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@ write_to = "src/hipscat_import/_version.py"

[tool.setuptools.package-data]
hipscat_import = ["py.typed"]
[project.scripts]
hipscat-import = "hipscat_import.control:main"
hc = "hipscat_import.control:main"

[tool.pytest.ini_options]
timeout = 1
markers = [
"dask: mark tests as having a dask client runtime dependency",
]

[tool.coverage.report]
omit = [
"src/hipscat_import/_version.py", # auto-generated
"src/hipscat_import/pipeline.py", # too annoying to test
]
11 changes: 3 additions & 8 deletions src/hipscat_import/association/run_association.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,16 @@
from tqdm import tqdm

from hipscat_import.association.arguments import AssociationArguments
from hipscat_import.association.map_reduce import (map_association,
reduce_association)
from hipscat_import.association.map_reduce import map_association, reduce_association


def _validate_args(args):
def run(args):
"""Run the association pipeline"""
if not args:
raise TypeError("args is required and should be type AssociationArguments")
if not isinstance(args, AssociationArguments):
raise TypeError("args must be type AssociationArguments")


def run(args):
"""Run the association pipeline"""
_validate_args(args)

with tqdm(total=1, desc="Mapping ", disable=not args.progress_bar) as step_progress:
map_association(args)
step_progress.update(1)
Expand Down
31 changes: 22 additions & 9 deletions src/hipscat_import/catalog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
"""All modules for importing new catalogs."""

from .arguments import ImportArguments
from .file_readers import (CsvReader, FitsReader, InputReader, ParquetReader,
get_file_reader)
from .file_readers import (
CsvReader,
FitsReader,
InputReader,
ParquetReader,
get_file_reader,
)
from .map_reduce import map_to_pixels, reduce_pixel_shards, split_pixels
from .resume_files import (clean_resume_files, is_mapping_done,
is_reducing_done, read_histogram, read_mapping_keys,
read_reducing_keys, set_mapping_done,
set_reducing_done, write_histogram,
write_mapping_done_key, write_mapping_start_key,
write_reducing_key)
from .run_import import run, run_with_client
from .resume_files import (
clean_resume_files,
is_mapping_done,
is_reducing_done,
read_histogram,
read_mapping_keys,
read_reducing_keys,
set_mapping_done,
set_reducing_done,
write_histogram,
write_mapping_done_key,
write_mapping_start_key,
write_reducing_key,
)
from .run_import import run
4 changes: 2 additions & 2 deletions src/hipscat_import/catalog/map_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def map_to_pixels(
Args:
input_file (FilePointer): file to read for catalog data.
file_reader (hipscat_import.catalog.file_readers.InputReader): instance of input
file_reader (hipscat_import.catalog.file_readers.InputReader): instance of input
reader that specifies arguments necessary for reading from the input file.
shard_suffix (str): unique counter for this input file, used
when creating intermediate files
Expand Down Expand Up @@ -137,7 +137,7 @@ def split_pixels(
Args:
input_file (FilePointer): file to read for catalog data.
file_reader (hipscat_import.catalog.file_readers.InputReader): instance
file_reader (hipscat_import.catalog.file_readers.InputReader): instance
of input reader that specifies arguments necessary for reading from the input file.
shard_suffix (str): unique counter for this input file, used
when creating intermediate files
Expand Down
24 changes: 3 additions & 21 deletions src/hipscat_import/catalog/run_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import hipscat.io.write_metadata as io
import numpy as np
from dask.distributed import Client, as_completed
from dask.distributed import as_completed
from hipscat import pixel_math
from tqdm import tqdm

Expand Down Expand Up @@ -151,30 +151,12 @@ def _reduce_pixels(args, destination_pixel_map, client):
resume.set_reducing_done(args.tmp_path)


def _validate_args(args):
def run(args, client):
"""Run catalog creation pipeline."""
if not args:
raise ValueError("args is required and should be type ImportArguments")
if not isinstance(args, ImportArguments):
raise ValueError("args must be type ImportArguments")


def run(args):
"""Importer that creates a dask client from the arguments"""
_validate_args(args)

# pylint: disable=duplicate-code
with Client(
local_directory=args.dask_tmp,
n_workers=args.dask_n_workers,
threads_per_worker=args.dask_threads_per_worker,
) as client: # pragma: no cover
run_with_client(args, client)
# pylint: enable=duplicate-code


def run_with_client(args, client):
"""Importer, where the client context may out-live the runner"""
_validate_args(args)
raw_histogram = _map_pixels(args, client)

with tqdm(
Expand Down
8 changes: 2 additions & 6 deletions src/hipscat_import/index/run_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@
from hipscat_import.index.arguments import IndexArguments


def _validate_args(args):
def run(args):
"""Run index creation pipeline."""
if not args:
raise TypeError("args is required and should be type IndexArguments")
if not isinstance(args, IndexArguments):
raise TypeError("args must be type IndexArguments")


def run(args):
"""Importer, where the client context may out-live the runner"""
_validate_args(args)
rows_written = mr.create_index(args)

# All done - write out the metadata
Expand Down
3 changes: 1 addition & 2 deletions src/hipscat_import/margin_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""All modules for generating margin caches."""
from .margin_cache import (generate_margin_cache,
generate_margin_cache_with_client)
from .margin_cache import generate_margin_cache
from .margin_cache_arguments import MarginCacheArguments
48 changes: 12 additions & 36 deletions src/hipscat_import/margin_cache/margin_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pandas as pd
from dask.distributed import Client, as_completed
from dask.distributed import as_completed
from hipscat import pixel_math
from hipscat.io import file_io, paths
from tqdm import tqdm
Expand All @@ -8,6 +8,7 @@

# pylint: disable=too-many-locals,too-many-arguments


def _find_partition_margin_pixel_pairs(stats, margin_order):
"""Creates a DataFrame filled with many-to-many connections between
the catalog partition pixels and the margin pixels at `margin_order`.
Expand All @@ -31,30 +32,28 @@ def _find_partition_margin_pixel_pairs(stats, margin_order):

margin_pairs_df = pd.DataFrame(
zip(norders, part_pix, margin_pix),
columns=["partition_order", "partition_pixel", "margin_pixel"]
columns=["partition_order", "partition_pixel", "margin_pixel"],
)
return margin_pairs_df


def _create_margin_directory(stats, output_path):
"""Creates directories for all the catalog partitions."""
for healpixel in stats:
order = healpixel.order
pix = healpixel.pixel

destination_dir = paths.pixel_directory(
output_path, order, pix
)
destination_dir = paths.pixel_directory(output_path, order, pix)
file_io.make_directory(destination_dir, exist_ok=True)


def _map_to_margin_shards(client, args, partition_pixels, margin_pairs):
"""Create all the jobs for mapping partition files into the margin cache."""
futures = []
mp_future = client.scatter(margin_pairs, broadcast=True)
for pix in partition_pixels:
partition_file = paths.pixel_catalog_file(
args.input_catalog_path,
pix.order,
pix.pixel
args.input_catalog_path, pix.order, pix.pixel
)
futures.append(
client.submit(
Expand All @@ -76,50 +75,27 @@ def _map_to_margin_shards(client, args, partition_pixels, margin_pairs):
):
...

def generate_margin_cache(args):

def generate_margin_cache(args, client):
"""Generate a margin cache for a given input catalog.
The input catalog must be in hipscat format.
This method will handle the creation of the `dask.distributed` client
based on the `dask_tmp`, `dask_n_workers`, and `dask_threads_per_worker`
values of the `MarginCacheArguments` object.
Args:
args (MarginCacheArguments): A valid `MarginCacheArguments` object.
"""
# pylint: disable=duplicate-code
with Client(
local_directory=args.dask_tmp,
n_workers=args.dask_n_workers,
threads_per_worker=args.dask_threads_per_worker,
) as client: # pragma: no cover
generate_margin_cache_with_client(
client,
args
)
# pylint: enable=duplicate-code

def generate_margin_cache_with_client(client, args):
"""Generate a margin cache for a given input catalog.
The input catalog must be in hipscat format.
Args:
client (dask.distributed.Client): A dask distributed client object.
args (MarginCacheArguments): A valid `MarginCacheArguments` object.
"""
# determine which order to generate margin pixels for
partition_stats = args.catalog.partition_info.get_healpix_pixels()

margin_pairs = _find_partition_margin_pixel_pairs(
partition_stats,
args.margin_order
partition_stats, args.margin_order
)

# arcsec to degree conversion
# TODO: remove this once hipscat uses arcsec for calculation
args.margin_threshold = args.margin_threshold / 3600.
args.margin_threshold = args.margin_threshold / 3600.0

_create_margin_directory(
partition_stats, args.catalog_path
)
_create_margin_directory(partition_stats, args.catalog_path)

_map_to_margin_shards(
client=client,
Expand Down
Loading

0 comments on commit f0cd4bf

Please sign in to comment.