diff --git a/src/hipscat_import/catalog/file_readers.py b/src/hipscat_import/catalog/file_readers.py
index af630e87..5a549737 100644
--- a/src/hipscat_import/catalog/file_readers.py
+++ b/src/hipscat_import/catalog/file_readers.py
@@ -291,18 +291,24 @@ class ParquetReader(InputReader):
         chunksize (int): number of rows of the file to process at once.
             For large files, this can prevent loading the entire file
            into memory at once.
+        column_names (list[str] or None): Names of columns to use from the input dataset.
+            If None, use all columns.
         kwargs: arguments to pass along to pyarrow.parquet.ParquetFile.
             See https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html
     """
 
-    def __init__(self, chunksize=500_000, **kwargs):
+    def __init__(self, chunksize=500_000, column_names=None, **kwargs):
         self.chunksize = chunksize
+        self.column_names = column_names
         self.kwargs = kwargs
 
     def read(self, input_file, read_columns=None):
         self.regular_file_exists(input_file, **self.kwargs)
+        columns = read_columns or self.column_names
         parquet_file = pq.ParquetFile(input_file, **self.kwargs)
-        for smaller_table in parquet_file.iter_batches(batch_size=self.chunksize, use_pandas_metadata=True):
+        for smaller_table in parquet_file.iter_batches(
+            batch_size=self.chunksize, columns=columns, use_pandas_metadata=True
+        ):
             yield smaller_table.to_pandas()
 
     def provenance_info(self) -> dict:
diff --git a/tests/hipscat_import/catalog/test_file_readers.py b/tests/hipscat_import/catalog/test_file_readers.py
index d7fd2a14..c7f76176 100644
--- a/tests/hipscat_import/catalog/test_file_readers.py
+++ b/tests/hipscat_import/catalog/test_file_readers.py
@@ -255,6 +255,19 @@ def test_parquet_reader_provenance_info(tmp_path, basic_catalog_info):
     io.write_provenance_info(catalog_base_dir, basic_catalog_info, provenance_info)
 
 
+def test_parquet_reader_columns(parquet_shards_shard_44_0):
+    """Verify we can read a subset of columns."""
+    column_subset = ["id", "dec"]
+
+    # test column_names constructor argument
+    for frame in ParquetReader(column_names=column_subset).read(parquet_shards_shard_44_0):
+        assert set(frame.columns) == set(column_subset)
+
+    # test read_columns kwarg
+    for frame in ParquetReader().read(parquet_shards_shard_44_0, read_columns=column_subset):
+        assert set(frame.columns) == set(column_subset)
+
+
 def test_read_fits(formats_fits):
     """Success case - fits file that exists being read as fits"""
     total_chunks = 0
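
Usage sketch (not part of the diff): a minimal example of how a caller might exercise the new column_names option and the per-call read_columns override. The file path "my_shard.parquet" and the column names are illustrative assumptions; the reader class and import path come from the patched module above.

    from hipscat_import.catalog.file_readers import ParquetReader

    # "my_shard.parquet", "id", and "dec" are hypothetical values for illustration.
    reader = ParquetReader(chunksize=100_000, column_names=["id", "dec"])
    for frame in reader.read("my_shard.parquet"):
        # Each chunk is a pandas DataFrame restricted to the requested columns.
        print(frame.columns.tolist())

    # A truthy read_columns argument takes precedence over column_names,
    # per `columns = read_columns or self.column_names` in the patch.
    for frame in ParquetReader(column_names=["id", "dec"]).read("my_shard.parquet", read_columns=["id"]):
        print(frame.columns.tolist())  # only ["id"]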