Specify types for partitioning columns.
delucchi-cmu committed Nov 28, 2023
1 parent 9f44b40 commit 71999d6
Showing 4 changed files with 63 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/hipscat_import/catalog/map_reduce.py
@@ -281,13 +281,13 @@ def reduce_pixel_shards(
         dataframe[dec_column].values,
     )

-    dataframe["Norder"] = np.full(rows_written, fill_value=destination_pixel_order, dtype=np.int32)
+    dataframe["Norder"] = np.full(rows_written, fill_value=destination_pixel_order, dtype=np.uint8)
     dataframe["Dir"] = np.full(
         rows_written,
         fill_value=int(destination_pixel_number / 10_000) * 10_000,
-        dtype=np.int32,
+        dtype=np.uint32,
     )
-    dataframe["Npix"] = np.full(rows_written, fill_value=destination_pixel_number, dtype=np.int32)
+    dataframe["Npix"] = np.full(rows_written, fill_value=destination_pixel_number, dtype=np.uint32)

     if add_hipscat_index:
         ## If we had a meaningful index before, preserve it as a column.
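For context, a minimal standalone sketch of what the new dtypes produce for the partitioning columns. The variable values below are hypothetical stand-ins, not taken from the commit; in map_reduce.py they come from the destination pixel.

    import numpy as np
    import pandas as pd

    # Hypothetical values for illustration only.
    rows_written = 4
    destination_pixel_order = 1
    destination_pixel_number = 44

    dataframe = pd.DataFrame({"id": np.arange(rows_written)})
    dataframe["Norder"] = np.full(rows_written, fill_value=destination_pixel_order, dtype=np.uint8)
    dataframe["Dir"] = np.full(
        rows_written,
        fill_value=int(destination_pixel_number / 10_000) * 10_000,
        dtype=np.uint32,
    )
    dataframe["Npix"] = np.full(rows_written, fill_value=destination_pixel_number, dtype=np.uint32)

    print(dataframe.dtypes)  # id int64, Norder uint8, Dir uint32, Npix uint32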
6 changes: 3 additions & 3 deletions src/hipscat_import/index/map_reduce.py
@@ -27,9 +27,9 @@ def create_index(args):

     if args.include_order_pixel:
         ## Take out the hive dictionary behavior.
-        data["Norder"] = data["Norder"].astype(np.int32)
-        data["Dir"] = data["Dir"].astype(np.int32)
-        data["Npix"] = data["Npix"].astype(np.int32)
+        data["Norder"] = data["Norder"].astype(np.uint8)
+        data["Dir"] = data["Dir"].astype(np.uint32)
+        data["Npix"] = data["Npix"].astype(np.uint32)
     data = data.reset_index()
     if not args.include_hipscat_index:
         data = data.drop(columns=[HIPSCAT_ID_COLUMN])
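The "Take out the hive dictionary behavior" comment refers to hive partition columns arriving as dictionary-encoded categoricals when a partitioned dataset is read back. A small sketch (with made-up values) of what the astype calls do in that situation:

    import numpy as np
    import pandas as pd

    # Simulated result of reading a hive-partitioned dataset:
    # partition columns come back as categoricals.
    data = pd.DataFrame(
        {
            "Norder": pd.Categorical([1, 1]),
            "Dir": pd.Categorical([0, 0]),
            "Npix": pd.Categorical([44, 45]),
        }
    )
    data["Norder"] = data["Norder"].astype(np.uint8)
    data["Dir"] = data["Dir"].astype(np.uint32)
    data["Npix"] = data["Npix"].astype(np.uint32)

    print(data.dtypes)  # Norder uint8, Dir uint32, Npix uint32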
50 changes: 49 additions & 1 deletion tests/hipscat_import/catalog/test_run_import.py
@@ -4,11 +4,16 @@
 import shutil

 import hipscat.pixel_math as hist
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 from hipscat.catalog.catalog import Catalog

 import hipscat_import.catalog.run_import as runner
 from hipscat_import.catalog.arguments import ImportArguments
+from hipscat_import.catalog.file_readers import CsvReader
 from hipscat_import.catalog.resume_plan import ResumePlan
@@ -133,11 +138,19 @@ def test_dask_runner(
     assert_parquet_file_ids,
     tmp_path,
 ):
-    """Test basic execution."""
+    """Test basic execution and the types of the written data."""
     args = ImportArguments(
         output_artifact_name="small_sky_object_catalog",
         input_path=small_sky_parts_dir,
         input_format="csv",
+        file_reader=CsvReader(
+            type_map={
+                "ra": np.float32,
+                "dec": np.float32,
+                "ra_error": np.float32,
+                "dec_error": np.float32,
+            }
+        ),
         output_path=tmp_path,
         dask_tmp=tmp_path,
         highest_healpix_order=1,
@@ -161,6 +174,41 @@
     expected_ids = [*range(700, 831)]
     assert_parquet_file_ids(output_file, "id", expected_ids)

+    # Check that the schema is correct for leaf parquet and _metadata files
+    expected_parquet_schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("ra", pa.float32()),
+            pa.field("dec", pa.float32()),
+            pa.field("ra_error", pa.float32()),
+            pa.field("dec_error", pa.float32()),
+            pa.field("Norder", pa.uint8()),
+            pa.field("Dir", pa.uint32()),
+            pa.field("Npix", pa.uint32()),
+            pa.field("_hipscat_index", pa.uint64()),
+        ]
+    )
+    schema = pq.read_metadata(output_file).schema.to_arrow_schema()
+    assert schema.equals(expected_parquet_schema, check_metadata=False)
+    schema = pq.read_metadata(os.path.join(args.catalog_path, "_metadata")).schema.to_arrow_schema()
+    assert schema.equals(expected_parquet_schema, check_metadata=False)
+
+    # Check that, when re-loaded as a pandas dataframe, the appropriate numeric types are used.
+    data_frame = pd.read_parquet(output_file, engine="pyarrow")
+    expected_dtypes = pd.Series(
+        {
+            "id": np.int64,
+            "ra": np.float32,
+            "dec": np.float32,
+            "ra_error": np.float32,
+            "dec_error": np.float32,
+            "Norder": np.uint8,
+            "Dir": np.uint32,
+            "Npix": np.uint32,
+        }
+    )
+    assert data_frame.dtypes.equals(expected_dtypes)
+

 @pytest.mark.dask
 def test_dask_runner_stats_only(dask_client, small_sky_parts_dir, tmp_path):
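A condensed sketch of the round-trip property those new assertions exercise: unsigned integer columns written by pandas/pyarrow keep their exact types when the parquet file is read back. The file path is hypothetical; this is not part of the test suite.

    import numpy as np
    import pandas as pd

    df = pd.DataFrame(
        {
            "Norder": np.array([1, 1], dtype=np.uint8),
            "Npix": np.array([44, 45], dtype=np.uint32),
        }
    )
    df.to_parquet("/tmp/roundtrip_example.parquet", engine="pyarrow")
    reloaded = pd.read_parquet("/tmp/roundtrip_example.parquet", engine="pyarrow")
    assert reloaded.dtypes.equals(df.dtypes)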
11 changes: 8 additions & 3 deletions tests/hipscat_import/index/test_run_index.py
@@ -49,12 +49,17 @@ def test_run_index(
     basic_index_parquet_schema = pa.schema(
         [
             pa.field("_hipscat_index", pa.uint64()),
-            pa.field("Norder", pa.int32()),
-            pa.field("Dir", pa.int32()),
-            pa.field("Npix", pa.int32()),
+            pa.field("Norder", pa.uint8()),
+            pa.field("Dir", pa.uint32()),
+            pa.field("Npix", pa.uint32()),
             pa.field("id", pa.int64()),
         ]
     )
+
     outfile = os.path.join(args.catalog_path, "index", "part.0.parquet")
     schema = pq.read_metadata(outfile).schema.to_arrow_schema()
     assert schema.equals(basic_index_parquet_schema, check_metadata=False)
+
+    schema = pq.read_metadata(os.path.join(args.catalog_path, "_metadata")).schema.to_arrow_schema()
+    assert schema.equals(basic_index_parquet_schema, check_metadata=False)
Expand Down
