diff --git a/src/hipscat_import/catalog/file_readers.py b/src/hipscat_import/catalog/file_readers.py index e0202dd..f94c2fc 100644 --- a/src/hipscat_import/catalog/file_readers.py +++ b/src/hipscat_import/catalog/file_readers.py @@ -306,7 +306,7 @@ def read(self, input_file, read_columns=None): table = Table.read(input_file, memmap=True, **self.kwargs) if read_columns: table.keep_columns(read_columns) - if self.column_names: + elif self.column_names: table.keep_columns(self.column_names) elif self.skip_column_names: table.remove_columns(self.skip_column_names) diff --git a/tests/hipscat_import/catalog/test_file_readers.py b/tests/hipscat_import/catalog/test_file_readers.py index 5438f75..8ceedb3 100644 --- a/tests/hipscat_import/catalog/test_file_readers.py +++ b/tests/hipscat_import/catalog/test_file_readers.py @@ -360,9 +360,17 @@ def test_read_fits_columns(formats_fits): frame = next(FitsReader(column_names=["id", "ra", "dec"]).read(formats_fits)) assert list(frame.columns) == ["id", "ra", "dec"] + frame = next(FitsReader(column_names=["id", "ra", "dec"]).read(formats_fits, read_columns=["ra", "dec"])) + assert list(frame.columns) == ["ra", "dec"] + frame = next(FitsReader(skip_column_names=["ra_error", "dec_error"]).read(formats_fits)) assert list(frame.columns) == ["id", "ra", "dec", "test_id"] + frame = next( + FitsReader(skip_column_names=["ra_error", "dec_error"]).read(formats_fits, read_columns=["ra", "dec"]) + ) + assert list(frame.columns) == ["ra", "dec"] + def test_fits_reader_provenance_info(tmp_path, basic_catalog_info): """Test that we get some provenance info and it is parseable into JSON."""