diff --git a/glue_wwt/viewer/data_viewer.py b/glue_wwt/viewer/data_viewer.py index 30ca01c5..fd2fb6ea 100644 --- a/glue_wwt/viewer/data_viewer.py +++ b/glue_wwt/viewer/data_viewer.py @@ -7,6 +7,7 @@ from glue.core.coordinates import WCSCoordinates from glue.logger import logger from pywwt import ViewerNotAvailableError +from pywwt.layers import guess_lon_lat_columns from .image_layer import WWTImageLayerArtist from .table_layer import WWTTableLayerArtist @@ -127,3 +128,13 @@ def __setgluestate__(cls, rec, context): camera_kwargs["roll"] = roll * u.deg viewer._wwt.center_on_coordinates(SkyCoord(ra, dec, unit=u.deg), **camera_kwargs) return viewer + + def add_data(self, data): + add = super().add_data(data) + if add and len(self.state.layers) == 1: + colnames = [c.label for c in data.components] + lon, lat = guess_lon_lat_columns(colnames) + if lon is not None and lat is not None: + self.state.lon_att = data.id[lon] + self.state.lat_att = data.id[lat] + return add diff --git a/glue_wwt/viewer/tests/test_wwt_widget.py b/glue_wwt/viewer/tests/test_wwt_widget.py index 14c36caf..ffc189da 100644 --- a/glue_wwt/viewer/tests/test_wwt_widget.py +++ b/glue_wwt/viewer/tests/test_wwt_widget.py @@ -9,7 +9,7 @@ from qtpy import compat -from glue.core import Data, message +from glue.core import ComponentLink, Data, message from glue.core.tests.test_state import clone from glue_qt.app import GlueApplication @@ -30,13 +30,16 @@ class TestWWTDataViewer(object): def setup_method(self, method): self.d = Data(x=[1, 2, 3], y=[2, 3, 4], z=[4, 5, 6]) + self.ra_dec_data = Data(ra=[-10, 0, 10], dec=[0, 10, 20]) self.bad_data_short = Data(x=[-100, 100], y=[-10, 10]) self.bad_data_long = Data(x=[-100, -90, -80, 80, 90, 100], y=[-10, -7, -3, 3, 7, 10]) self.application = GlueApplication() self.dc = self.application.data_collection self.dc.append(self.d) + self.dc.append(self.ra_dec_data) self.dc.append(self.bad_data_short) self.dc.append(self.bad_data_long) + self.dc.add_link(ComponentLink([self.d.id['x']], self.d.id['y'])) self.hub = self.dc.hub self.session = self.application.session self.viewer = self.application.new_data_viewer(WWTQtViewerBlocking) @@ -201,6 +204,28 @@ def test_skycoord_exception_message_long(self): disabled_message = create_disabled_message(disabled_reason) assert layer.disabled_message == disabled_message + def test_guess_ra_dec_columns(self): + + # If the first `Data` that we add has columns that should lend + # themselves towards guessable RA/Dec components, check that + # the viewer state attributes get set correctly + self.viewer.add_data(self.ra_dec_data) + assert self.viewer.state.lon_att is self.ra_dec_data.id['ra'] + assert self.viewer.state.lat_att is self.ra_dec_data.id['dec'] + + def test_no_guess_ra_dec_columns(self): + + # Check that we correctly DON'T guess RA/Dec columns + self.viewer.add_data(self.d) + assert self.viewer.state.lon_att is self.d.id['x'] + assert self.viewer.state.lat_att is self.d.id['y'] + + # Check that if we add a second `Data` with valid RA/Dec behavior, + # we don't override the current state + self.viewer.add_data(self.ra_dec_data) + assert self.viewer.state.lon_att is self.d.id['x'] + assert self.viewer.state.lat_att is self.d.id['y'] + # TODO: determine if the following test is the desired behavior # def test_subsets_not_live_added_if_data_not_present(self): # self.register()