Support reading a file to pyarrow table in catalog import #415

Merged 1 commit on Oct 21, 2024
30 changes: 29 additions & 1 deletion src/hats_import/catalog/file_readers.py
@@ -3,7 +3,7 @@
import abc

import pandas as pd
import pyarrow
import pyarrow as pa
import pyarrow.dataset
import pyarrow.parquet as pq
from astropy.io import ascii as ascii_reader
@@ -356,6 +356,34 @@ def read(self, input_file, read_columns=None):
yield smaller_table.to_pandas()


class ParquetPyarrowReader(InputReader):
"""Parquet reader for the most common Parquet reading arguments.

Attributes:
chunksize (int): number of rows of the file to process at once.
For large files, this can prevent loading the entire file
into memory at once.
column_names (list[str] or None): Names of columns to use from the input dataset.
If None, use all columns.
kwargs: arguments to pass along to pyarrow.parquet.ParquetFile.
See https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html
"""

def __init__(self, chunksize=500_000, column_names=None, **kwargs):
self.chunksize = chunksize
self.column_names = column_names
self.kwargs = kwargs

def read(self, input_file, read_columns=None):
self.regular_file_exists(input_file, **self.kwargs)
columns = read_columns or self.column_names
parquet_file = pq.ParquetFile(input_file, **self.kwargs)
for smaller_table in parquet_file.iter_batches(batch_size=self.chunksize, columns=columns):
table = pa.Table.from_batches([smaller_table])
table = table.replace_schema_metadata()
yield table
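For illustration, a minimal standalone sketch of the chunked iteration the new reader performs, using a hypothetical input file name and column list rather than the reader class itself:

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Hypothetical input: any local parquet file containing ra/dec columns.
parquet_file = pq.ParquetFile("survey_chunk.parquet")
for batch in parquet_file.iter_batches(batch_size=500_000, columns=["ra", "dec"]):
    table = pa.Table.from_batches([batch])
    # Clearing schema metadata drops any pandas index information stored in the file.
    table = table.replace_schema_metadata()
    print(table.num_rows, table.column_names)
```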


class IndexedParquetReader(InputReader):
"""Reads an index file, containing paths to parquet files to be read and batched

115 changes: 53 additions & 62 deletions src/hats_import/catalog/map_reduce.py
@@ -4,6 +4,7 @@

import hats.pixel_math.healpix_shim as hp
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from hats import pixel_math
@@ -50,21 +51,30 @@ def _iterate_input_file(

for chunk_number, data in enumerate(file_reader.read(input_file, read_columns=read_columns)):
if use_healpix_29:
if data.index.name == SPATIAL_INDEX_COLUMN:
if isinstance(data, pd.DataFrame) and data.index.name == SPATIAL_INDEX_COLUMN:
mapped_pixels = spatial_index_to_healpix(data.index, target_order=highest_order)
else:
mapped_pixels = spatial_index_to_healpix(
data[SPATIAL_INDEX_COLUMN], target_order=highest_order
)
else:
# Set up the pixel data
mapped_pixels = hp.ang2pix(
2**highest_order,
data[ra_column].to_numpy(copy=False, dtype=float),
data[dec_column].to_numpy(copy=False, dtype=float),
lonlat=True,
nest=True,
)
if isinstance(data, pd.DataFrame):
mapped_pixels = hp.ang2pix(
2**highest_order,
data[ra_column].to_numpy(copy=False, dtype=float),
data[dec_column].to_numpy(copy=False, dtype=float),
lonlat=True,
nest=True,
)
else:
mapped_pixels = hp.ang2pix(
2**highest_order,
data[ra_column].to_numpy(),
data[dec_column].to_numpy(),
lonlat=True,
nest=True,
)
yield chunk_number, data, mapped_pixels
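As a self-contained sketch of the type branch above (the helper name is made up for illustration): a pandas column is pulled out as a float array without copying where possible, while a pyarrow column converts via to_numpy():

```python
import numpy as np
import pandas as pd
import pyarrow as pa

def coordinate_array(data, column):
    """Return a numpy array for `column` from either a pandas DataFrame or a pyarrow Table."""
    if isinstance(data, pd.DataFrame):
        # pandas path: ask for floats and avoid a copy when the dtype already matches.
        return data[column].to_numpy(copy=False, dtype=float)
    # pyarrow path: data[column] is a ChunkedArray, which converts directly.
    return data[column].to_numpy()

frame = pd.DataFrame({"ra": [10.0, 20.0]})
table = pa.table({"ra": [10.0, 20.0]})
assert np.array_equal(coordinate_array(frame, "ra"), coordinate_array(table, "ra"))
```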


@@ -168,17 +178,20 @@ def split_pixels(
unique_pixels, unique_inverse = np.unique(aligned_pixels, return_inverse=True)

for unique_index, [order, pixel, _] in enumerate(unique_pixels):
filtered_data = data.iloc[unique_inverse == unique_index]

pixel_dir = get_pixel_cache_directory(cache_shard_path, HealpixPixel(order, pixel))
file_io.make_directory(pixel_dir, exist_ok=True)
output_file = file_io.append_paths_to_pointer(
pixel_dir, f"shard_{splitting_key}_{chunk_number}.parquet"
)
if _has_named_index(filtered_data):
filtered_data.to_parquet(output_file.path, index=True, filesystem=output_file.fs)
if isinstance(data, pd.DataFrame):
filtered_data = data.iloc[unique_inverse == unique_index]
if _has_named_index(filtered_data):
filtered_data = filtered_data.reset_index()
filtered_data = pa.Table.from_pandas(filtered_data, preserve_index=False)
else:
filtered_data.to_parquet(output_file.path, index=False, filesystem=output_file.fs)
filtered_data = data.filter(unique_inverse == unique_index)

pq.write_table(filtered_data, output_file.path, filesystem=output_file.fs)
del filtered_data

ResumePlan.splitting_key_done(tmp_path=resume_path, splitting_key=splitting_key)
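To isolate the filtering difference with toy data: pandas selects rows positionally through iloc with a boolean mask, while a pyarrow Table takes the same mask through filter(). A minimal sketch:

```python
import numpy as np
import pandas as pd
import pyarrow as pa

unique_inverse = np.array([0, 0, 1])

frame = pd.DataFrame({"id": [700, 701, 702]})
table = pa.table({"id": [700, 701, 702]})

# Rows mapped to the first unique pixel, via either container type.
frame_shard = frame.iloc[unique_inverse == 0]
table_shard = table.filter(unique_inverse == 0)
assert list(frame_shard["id"]) == table_shard["id"].to_pylist()
```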
@@ -258,15 +271,10 @@ def reduce_pixel_shards(
if use_schema_file:
schema = file_io.read_parquet_metadata(use_schema_file).schema.to_arrow_schema()

tables = []
healpix_pixel = HealpixPixel(destination_pixel_order, destination_pixel_number)
pixel_dir = get_pixel_cache_directory(cache_shard_path, healpix_pixel)

if schema:
tables.append(pq.read_table(pixel_dir, schema=schema))
else:
tables.append(pq.read_table(pixel_dir))

merged_table = pa.concat_tables(tables)
merged_table = pq.read_table(pixel_dir, schema=schema)

rows_written = len(merged_table)

@@ -277,38 +285,36 @@
f" Expected {destination_pixel_size}, wrote {rows_written}"
)

dataframe = merged_table.to_pandas()
if sort_columns:
dataframe = dataframe.sort_values(sort_columns.split(","), kind="stable")
split_columns = sort_columns.split(",")
if len(split_columns) > 1:
merged_table = merged_table.sort_by([(col_name, "ascending") for col_name in split_columns])
else:
merged_table = merged_table.sort_by(sort_columns)
if add_healpix_29:
## If we had a meaningful index before, preserve it as a column.
if _has_named_index(dataframe):
dataframe = dataframe.reset_index()

dataframe[SPATIAL_INDEX_COLUMN] = pixel_math.compute_spatial_index(
dataframe[ra_column].values,
dataframe[dec_column].values,
)
dataframe = dataframe.set_index(SPATIAL_INDEX_COLUMN).sort_index(kind="stable")

# Adjust the schema to make sure that the _healpix_29 will
# be saved as a uint64
merged_table = merged_table.add_column(
0,
SPATIAL_INDEX_COLUMN,
[
pixel_math.compute_spatial_index(
merged_table[ra_column].to_numpy(),
merged_table[dec_column].to_numpy(),
)
],
).sort_by(SPATIAL_INDEX_COLUMN)
elif use_healpix_29:
if dataframe.index.name != SPATIAL_INDEX_COLUMN:
dataframe = dataframe.set_index(SPATIAL_INDEX_COLUMN)
dataframe = dataframe.sort_index(kind="stable")
merged_table = merged_table.sort_by(SPATIAL_INDEX_COLUMN)

dataframe["Norder"] = np.full(rows_written, fill_value=healpix_pixel.order, dtype=np.uint8)
dataframe["Dir"] = np.full(rows_written, fill_value=healpix_pixel.dir, dtype=np.uint64)
dataframe["Npix"] = np.full(rows_written, fill_value=healpix_pixel.pixel, dtype=np.uint64)

if schema:
schema = _modify_arrow_schema(schema, add_healpix_29)
dataframe.to_parquet(destination_file.path, schema=schema, filesystem=destination_file.fs)
else:
dataframe.to_parquet(destination_file.path, filesystem=destination_file.fs)
merged_table = (
merged_table.append_column(
"Norder", [np.full(rows_written, fill_value=healpix_pixel.order, dtype=np.uint8)]
)
.append_column("Dir", [np.full(rows_written, fill_value=healpix_pixel.dir, dtype=np.uint64)])
.append_column("Npix", [np.full(rows_written, fill_value=healpix_pixel.pixel, dtype=np.uint64)])
)

del dataframe, merged_table, tables
pq.write_table(merged_table, destination_file.path, filesystem=destination_file.fs)
del merged_table

if delete_input_files:
pixel_dir = get_pixel_cache_directory(cache_shard_path, healpix_pixel)
@@ -322,18 +328,3 @@
exception,
)
raise exception
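A compact, self-contained sketch of the pyarrow Table operations the rewritten reduce step relies on (add_column, sort_by, append_column), with made-up values standing in for the compute_spatial_index output:

```python
import numpy as np
import pyarrow as pa

table = pa.table({"ra": [20.0, 10.0], "dec": [-20.0, -10.0]})

# Made-up stand-ins for the pixel_math.compute_spatial_index output.
healpix_29 = np.array([9, 4], dtype=np.int64)

table = table.add_column(0, "_healpix_29", [healpix_29]).sort_by("_healpix_29")
table = (
    table.append_column("Norder", [np.full(2, 0, dtype=np.uint8)])
    .append_column("Dir", [np.full(2, 0, dtype=np.uint64)])
    .append_column("Npix", [np.full(2, 11, dtype=np.uint64)])
)
print(table.schema)  # _healpix_29 int64, ra/dec, then Norder uint8, Dir/Npix uint64
```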


def _modify_arrow_schema(schema, add_healpix_29):
if add_healpix_29:
pandas_index_column = schema.get_field_index("__index_level_0__")
if pandas_index_column != -1:
schema = schema.remove(pandas_index_column)
schema = schema.insert(0, pa.field(SPATIAL_INDEX_COLUMN, pa.int64()))
schema = (
schema.append(pa.field("Norder", pa.uint8()))
.append(pa.field("Dir", pa.uint64()))
.append(pa.field("Npix", pa.uint64()))
)

return schema
7 changes: 5 additions & 2 deletions src/hats_import/index/map_reduce.py
@@ -17,8 +17,9 @@
schema=schema,
)

data = data.reset_index()
if not include_healpix_29:
if data.index.name == SPATIAL_INDEX_COLUMN:
data = data.reset_index()

Codecov: added line src/hats_import/index/map_reduce.py#L21 was not covered by tests.
if not include_healpix_29 and SPATIAL_INDEX_COLUMN in data.columns:
data = data.drop(columns=[SPATIAL_INDEX_COLUMN])
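The index-handling branch above, restated as a toy sketch: the spatial index is only moved out of the index and then dropped when include_healpix_29 is false:

```python
import pandas as pd

SPATIAL_INDEX_COLUMN = "_healpix_29"  # same constant name the module uses

data = pd.DataFrame({"id": [1, 2]}, index=pd.Index([7, 8], name=SPATIAL_INDEX_COLUMN))
include_healpix_29 = False

if not include_healpix_29 and data.index.name == SPATIAL_INDEX_COLUMN:
    data = data.reset_index()
if not include_healpix_29 and SPATIAL_INDEX_COLUMN in data.columns:
    data = data.drop(columns=[SPATIAL_INDEX_COLUMN])

print(list(data.columns))  # ['id']
```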

if drop_duplicates:
@@ -32,6 +33,8 @@
include_columns = [args.indexing_column]
if args.extra_columns:
include_columns.extend(args.extra_columns)
if args.include_healpix_29:
include_columns.append(SPATIAL_INDEX_COLUMN)
if args.include_order_pixel:
include_columns.extend(["Norder", "Dir", "Npix"])

Binary file modified tests/data/small_sky_object_catalog/dataset/_common_metadata
Binary file modified tests/data/small_sky_object_catalog/dataset/_metadata
6 changes: 3 additions & 3 deletions tests/data/small_sky_object_catalog/properties
@@ -7,8 +7,8 @@ hats_col_dec=dec
hats_max_rows=1000000
hats_order=0
moc_sky_fraction=0.08333
hats_builder=hats-import v0.3.6.dev26+g40366b4
hats_creation_date=2024-10-11T15\:02UTC
hats_estsize=74
hats_builder=hats-import v0.4.1.dev2+gaeb92ae
hats_creation_date=2024-10-21T13\:22UTC
hats_estsize=70
hats_release_date=2024-09-18
hats_version=v0.1
Binary file modified tests/data/small_sky_source_catalog/dataset/_common_metadata
Binary file modified tests/data/small_sky_source_catalog/dataset/_metadata
6 changes: 3 additions & 3 deletions tests/data/small_sky_source_catalog/properties
@@ -7,8 +7,8 @@ hats_col_dec=source_dec
hats_max_rows=3000
hats_order=2
moc_sky_fraction=0.16667
hats_builder=hats-import v0.3.6.dev26+g40366b4
hats_creation_date=2024-10-11T15\:02UTC
hats_estsize=1105
hats_builder=hats-import v0.4.1.dev2+gaeb92ae
hats_creation_date=2024-10-21T13\:22UTC
hats_estsize=1083
hats_release_date=2024-09-18
hats_version=v0.1
3 changes: 1 addition & 2 deletions tests/hats_import/catalog/test_map_reduce.py
@@ -353,10 +353,9 @@ def test_reduce_healpix_29(parquet_shards_dir, assert_parquet_file_ids, tmp_path
expected_ids = [*range(700, 831)]
assert_parquet_file_ids(output_file, "id", expected_ids)
data_frame = pd.read_parquet(output_file, engine="pyarrow")
assert data_frame.index.name == "_healpix_29"
npt.assert_array_equal(
data_frame.columns,
["id", "ra", "dec", "ra_error", "dec_error", "Norder", "Dir", "Npix"],
["_healpix_29", "id", "ra", "dec", "ra_error", "dec_error", "Norder", "Dir", "Npix"],
)

mr.reduce_pixel_shards(
3 changes: 2 additions & 1 deletion tests/hats_import/catalog/test_run_import.py
@@ -278,6 +278,7 @@ def test_dask_runner(
# Check that the schema is correct for leaf parquet and _metadata files
expected_parquet_schema = pa.schema(
[
pa.field("_healpix_29", pa.int64()),
pa.field("id", pa.int64()),
pa.field("ra", pa.float32()),
pa.field("dec", pa.float32()),
@@ -286,7 +287,6 @@
pa.field("Norder", pa.uint8()),
pa.field("Dir", pa.uint64()),
pa.field("Npix", pa.uint64()),
pa.field("_healpix_29", pa.int64()),
]
)
schema = pq.read_metadata(output_file).schema.to_arrow_schema()
@@ -298,6 +298,7 @@
data_frame = pd.read_parquet(output_file, engine="pyarrow")
expected_dtypes = pd.Series(
{
"_healpix_29": np.int64,
"id": np.int64,
"ra": np.float32,
"dec": np.float32,