From 71999d681337a23168ce1868a881684920291861 Mon Sep 17 00:00:00 2001
From: delucchi-cmu
Date: Tue, 28 Nov 2023 12:55:19 -0500
Subject: [PATCH] Specify types for partitioning columns.

---
 src/hipscat_import/catalog/map_reduce.py      |  6 +--
 src/hipscat_import/index/map_reduce.py        |  6 +--
 .../hipscat_import/catalog/test_run_import.py | 50 ++++++++++++++++++-
 tests/hipscat_import/index/test_run_index.py  | 11 ++--
 4 files changed, 63 insertions(+), 10 deletions(-)

diff --git a/src/hipscat_import/catalog/map_reduce.py b/src/hipscat_import/catalog/map_reduce.py
index fba940fb..a29ccc29 100644
--- a/src/hipscat_import/catalog/map_reduce.py
+++ b/src/hipscat_import/catalog/map_reduce.py
@@ -281,13 +281,13 @@ def reduce_pixel_shards(
         dataframe[dec_column].values,
     )
 
-    dataframe["Norder"] = np.full(rows_written, fill_value=destination_pixel_order, dtype=np.int32)
+    dataframe["Norder"] = np.full(rows_written, fill_value=destination_pixel_order, dtype=np.uint8)
     dataframe["Dir"] = np.full(
         rows_written,
         fill_value=int(destination_pixel_number / 10_000) * 10_000,
-        dtype=np.int32,
+        dtype=np.uint32,
     )
-    dataframe["Npix"] = np.full(rows_written, fill_value=destination_pixel_number, dtype=np.int32)
+    dataframe["Npix"] = np.full(rows_written, fill_value=destination_pixel_number, dtype=np.uint32)
 
     if add_hipscat_index:
         ## If we had a meaningful index before, preserve it as a column.
diff --git a/src/hipscat_import/index/map_reduce.py b/src/hipscat_import/index/map_reduce.py
index e446a2a9..06c6ad23 100644
--- a/src/hipscat_import/index/map_reduce.py
+++ b/src/hipscat_import/index/map_reduce.py
@@ -27,9 +27,9 @@ def create_index(args):
 
     if args.include_order_pixel:
         ## Take out the hive dictionary behavior.
-        data["Norder"] = data["Norder"].astype(np.int32)
-        data["Dir"] = data["Dir"].astype(np.int32)
-        data["Npix"] = data["Npix"].astype(np.int32)
+        data["Norder"] = data["Norder"].astype(np.uint8)
+        data["Dir"] = data["Dir"].astype(np.uint32)
+        data["Npix"] = data["Npix"].astype(np.uint32)
     data = data.reset_index()
     if not args.include_hipscat_index:
         data = data.drop(columns=[HIPSCAT_ID_COLUMN])
diff --git a/tests/hipscat_import/catalog/test_run_import.py b/tests/hipscat_import/catalog/test_run_import.py
index 6f88a270..aae8467b 100644
--- a/tests/hipscat_import/catalog/test_run_import.py
+++ b/tests/hipscat_import/catalog/test_run_import.py
@@ -4,11 +4,16 @@
 import shutil
 
 import hipscat.pixel_math as hist
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 from hipscat.catalog.catalog import Catalog
 
 import hipscat_import.catalog.run_import as runner
 from hipscat_import.catalog.arguments import ImportArguments
+from hipscat_import.catalog.file_readers import CsvReader
 from hipscat_import.catalog.resume_plan import ResumePlan
 
 
@@ -133,11 +138,19 @@ def test_dask_runner(
     assert_parquet_file_ids,
     tmp_path,
 ):
-    """Test basic execution."""
+    """Test basic execution and the types of the written data."""
     args = ImportArguments(
         output_artifact_name="small_sky_object_catalog",
         input_path=small_sky_parts_dir,
         input_format="csv",
+        file_reader=CsvReader(
+            type_map={
+                "ra": np.float32,
+                "dec": np.float32,
+                "ra_error": np.float32,
+                "dec_error": np.float32,
+            }
+        ),
         output_path=tmp_path,
         dask_tmp=tmp_path,
         highest_healpix_order=1,
@@ -161,6 +174,41 @@
     expected_ids = [*range(700, 831)]
     assert_parquet_file_ids(output_file, "id", expected_ids)
+    # Check that the schema is correct for leaf parquet and _metadata files
+    expected_parquet_schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("ra", pa.float32()),
+            pa.field("dec", pa.float32()),
+            pa.field("ra_error", pa.float32()),
+            pa.field("dec_error", pa.float32()),
+            pa.field("Norder", pa.uint8()),
+            pa.field("Dir", pa.uint32()),
+            pa.field("Npix", pa.uint32()),
+            pa.field("_hipscat_index", pa.uint64()),
+        ]
+    )
+    schema = pq.read_metadata(output_file).schema.to_arrow_schema()
+    assert schema.equals(expected_parquet_schema, check_metadata=False)
+    schema = pq.read_metadata(os.path.join(args.catalog_path, "_metadata")).schema.to_arrow_schema()
+    assert schema.equals(expected_parquet_schema, check_metadata=False)
+
+    # Check that, when re-loaded as a pandas dataframe, the appropriate numeric types are used.
+    data_frame = pd.read_parquet(output_file, engine="pyarrow")
+    expected_dtypes = pd.Series(
+        {
+            "id": np.int64,
+            "ra": np.float32,
+            "dec": np.float32,
+            "ra_error": np.float32,
+            "dec_error": np.float32,
+            "Norder": np.uint8,
+            "Dir": np.uint32,
+            "Npix": np.uint32,
+        }
+    )
+    assert data_frame.dtypes.equals(expected_dtypes)
+
 
 
 @pytest.mark.dask
 def test_dask_runner_stats_only(dask_client, small_sky_parts_dir, tmp_path):
diff --git a/tests/hipscat_import/index/test_run_index.py b/tests/hipscat_import/index/test_run_index.py
index aaee233a..b6749e17 100644
--- a/tests/hipscat_import/index/test_run_index.py
+++ b/tests/hipscat_import/index/test_run_index.py
@@ -49,12 +49,17 @@
     basic_index_parquet_schema = pa.schema(
         [
             pa.field("_hipscat_index", pa.uint64()),
-            pa.field("Norder", pa.int32()),
-            pa.field("Dir", pa.int32()),
-            pa.field("Npix", pa.int32()),
+            pa.field("Norder", pa.uint8()),
+            pa.field("Dir", pa.uint32()),
+            pa.field("Npix", pa.uint32()),
             pa.field("id", pa.int64()),
         ]
     )
+
+    outfile = os.path.join(args.catalog_path, "index", "part.0.parquet")
+    schema = pq.read_metadata(outfile).schema.to_arrow_schema()
+    assert schema.equals(basic_index_parquet_schema, check_metadata=False)
+
     schema = pq.read_metadata(os.path.join(args.catalog_path, "_metadata")).schema.to_arrow_schema()
     assert schema.equals(basic_index_parquet_schema, check_metadata=False)
 
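
For reference (not part of the patch): a minimal, standalone sketch of the round trip the new tests exercise, writing the partitioning columns with the unsigned dtypes chosen above and checking that both the parquet schema and a pandas reload preserve them. The sample values and the temporary file name are invented for illustration.

import os
import tempfile

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

# Two fake rows with the partitioning columns typed the way the patch writes them.
frame = pd.DataFrame(
    {
        "id": np.array([700, 701], dtype=np.int64),
        "Norder": np.full(2, fill_value=1, dtype=np.uint8),
        "Dir": np.full(2, fill_value=0, dtype=np.uint32),
        "Npix": np.full(2, fill_value=44, dtype=np.uint32),
    }
)

with tempfile.TemporaryDirectory() as tmp_dir:
    out_file = os.path.join(tmp_dir, "part.0.parquet")
    frame.to_parquet(out_file, engine="pyarrow")

    # The parquet footer records the narrow unsigned types...
    schema = pq.read_metadata(out_file).schema.to_arrow_schema()
    assert schema.field("Norder").type == pa.uint8()
    assert schema.field("Dir").type == pa.uint32()
    assert schema.field("Npix").type == pa.uint32()

    # ...and pandas reloads the same dtypes instead of widening to int64.
    reloaded = pd.read_parquet(out_file, engine="pyarrow")
    assert reloaded["Norder"].dtype == np.uint8
    assert reloaded["Npix"].dtype == np.uint32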