diff --git a/glue/core/data_factories/fits.py b/glue/core/data_factories/fits.py index c21a89783..566112f53 100644 --- a/glue/core/data_factories/fits.py +++ b/glue/core/data_factories/fits.py @@ -2,7 +2,7 @@ import warnings from os.path import basename from collections import OrderedDict - +from astropy.coordinates import SkyCoord from glue.core.coordinates import coordinates_from_header, WCSCoordinates from glue.core.data import Component, Data from glue.config import data_factory, cli_parser @@ -151,9 +151,16 @@ def new_data(suffix=True): if column.ndim != 1: warnings.warn("Dropping column '{0}' since it is not 1-dimensional".format(column_name)) continue - component = Component.autotyped(column, units=column.unit) - data.add_component(component=component, - label=column_name) + if isinstance(column, SkyCoord): + for attribute_name in column.get_representation_component_names(): + values = getattr(column, attribute_name) + component = Component.autotyped(values, units=values.unit) + data.add_component(component=component, + label=f"{column_name}.{attribute_name}") + else: + component = Component.autotyped(column, units=column.unit) + data.add_component(component=component, + label=column_name) if close_hdulist: hdulist.close() diff --git a/glue/core/data_factories/tests/test_fits.py b/glue/core/data_factories/tests/test_fits.py index 03351822f..52d9bba4c 100644 --- a/glue/core/data_factories/tests/test_fits.py +++ b/glue/core/data_factories/tests/test_fits.py @@ -6,6 +6,10 @@ import numpy as np from numpy.testing import assert_array_equal +from astropy.table import Table +from astropy.coordinates import SkyCoord +from astropy import units as u + from glue.core import data_factories as df from glue.tests.helpers import requires_astropy, make_file @@ -229,3 +233,31 @@ def test_coordinate_component_units(): assert d.get_component(wid[2]).units == 'deg' assert wid[3].label == 'Right Ascension' assert d.get_component(wid[3]).units == 'deg' + + +@requires_astropy +def test_mixin_columns(tmp_path): + + t = Table() + t['coords'] = SkyCoord([1, 2, 3] * u.deg, [4, 5, 6] * u.deg, frame='galactic') + t['coords_with_distance'] = SkyCoord([1, 2, 3] * u.deg, [4, 5, 6] * u.deg, [7, 8, 9] * u.kpc, frame='fk5') + + t.write(tmp_path / 'table.fits', format='fits') + + d = fits_reader(tmp_path / 'table.fits')[0] + + assert len(d.main_components) == 6 + + assert_array_equal(d['coords.l'], [1, 2, 3]) + assert d.get_component('coords.l').units == 'deg' + assert_array_equal(d['coords.b'], [4, 5, 6]) + assert d.get_component('coords.b').units == 'deg' + assert_array_equal(d['coords.distance'], [1, 1, 1]) + assert d.get_component('coords.distance').units == '' + + assert_array_equal(d['coords_with_distance.ra'], [1, 2, 3]) + assert d.get_component('coords_with_distance.ra').units == 'deg' + assert_array_equal(d['coords_with_distance.dec'], [4, 5, 6]) + assert d.get_component('coords_with_distance.dec').units == 'deg' + assert_array_equal(d['coords_with_distance.distance'], [7, 8, 9]) + assert d.get_component('coords_with_distance.distance').units == 'kpc'