diff --git a/examples/introduction.ipynb b/examples/introduction.ipynb index 7a2da43..5746efc 100644 --- a/examples/introduction.ipynb +++ b/examples/introduction.ipynb @@ -246,7 +246,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.10.5" } }, "nbformat": 4, diff --git a/examples/vector.ipynb b/examples/vector.ipynb index 6c9e135..21e7a22 100644 --- a/examples/vector.ipynb +++ b/examples/vector.ipynb @@ -10,7 +10,7 @@ "import geopandas as gpd\n", "from ipyleaflet import LayersControl, Map, WidgetControl, basemaps\n", "from ipywidgets import FloatSlider\n", - "from xarray_leaflet import LeafletMap\n", + "import xarray_leaflet\n", "import matplotlib.pyplot as plt" ] }, @@ -54,8 +54,7 @@ "metadata": {}, "outputs": [], "source": [ - "lm = LeafletMap(df=df)\n", - "l = lm.plot(m, measurement=\"mask\", dynamic=True, colormap=plt.cm.inferno)" + "l = df.leaflet.plot(m, measurement=\"mask\", colormap=plt.cm.inferno)" ] }, { diff --git a/setup.cfg b/setup.cfg index e588a36..825e952 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,7 @@ install_requires = mercantile >=1 ipyspin >=0.1.1 ipyurl >=0.1.2 - geocube + geocube <1.0.0 pygeos >=0.12,<1.0.0 zarr >=2.0.0,<3.0.0 diff --git a/ui-tests/notebooks/test_vector.ipynb b/ui-tests/notebooks/test_vector.ipynb index 63f4a74..ea33e5a 100644 --- a/ui-tests/notebooks/test_vector.ipynb +++ b/ui-tests/notebooks/test_vector.ipynb @@ -9,7 +9,7 @@ "import geopandas\n", "import matplotlib.pyplot as plt\n", "from ipyleaflet import Map\n", - "from xarray_leaflet import LeafletMap" + "import xarray_leaflet" ] }, { @@ -39,7 +39,7 @@ "metadata": {}, "outputs": [], "source": [ - "l = LeafletMap(df=df).plot(m, fit_bounds=False, colormap=plt.cm.inferno, measurement=\"mask\")" + "l = df.leaflet.plot(m, fit_bounds=False, colormap=plt.cm.inferno, measurement=\"mask\")" ] } ], diff --git a/xarray_leaflet/__init__.py b/xarray_leaflet/__init__.py index 4c47a9f..9e8df7d 100644 --- a/xarray_leaflet/__init__.py +++ b/xarray_leaflet/__init__.py @@ -3,6 +3,7 @@ from .server_extension import _jupyter_nbextension_paths # noqa from .server_extension import _jupyter_server_extension_paths # noqa from .server_extension import _load_jupyter_server_extension -from .xarray_leaflet import LeafletMap # noqa +from .xarray_leaflet import DataArrayLeaflet # noqa +from .xarray_leaflet import GeoDataFrameLeaflet # noqa load_jupyter_server_extension = _load_jupyter_server_extension diff --git a/xarray_leaflet/xarray_leaflet.py b/xarray_leaflet/xarray_leaflet.py index d2e4862..d7b2b55 100644 --- a/xarray_leaflet/xarray_leaflet.py +++ b/xarray_leaflet/xarray_leaflet.py @@ -7,6 +7,7 @@ import matplotlib as mpl import mercantile import numpy as np +import pandas as pd import xarray as xr from ipyleaflet import DrawControl, LocalTileLayer, WidgetControl from ipyspin import Spinner @@ -30,9 +31,9 @@ from .vector import Zvect -@xr.register_dataarray_accessor("leaflet") -class LeafletMap(HasTraits): - """A xarray.DataArray extension for tiled map plotting, based on (ipy)leaflet.""" +class Leaflet(HasTraits): + + is_vector: bool map_ready = Bool(False) @@ -40,11 +41,6 @@ class LeafletMap(HasTraits): def _map_ready_changed(self, change): self._start() - def __init__(self, da: xr.DataArray = None, df: gpd.GeoDataFrame = None): - self._da = da - self._df = df - self._da_selected = None - def plot( self, m, @@ -127,14 +123,7 @@ def plot( self.layer = LocalTileLayer() - source_nb = sum([0 if i is None else 1 for i in (self._da, self._df)]) - if source_nb == 0: - raise RuntimeError("No DataArray or GeoDataFrame provided") - - if source_nb > 1: - raise RuntimeError("Only one of DataArray or GeoDataFrame must be provided") - - if self._df is not None: + if self.is_vector: # source is a GeoDataFrame (vector) if measurement is None: raise RuntimeError("You must provide a 'measurement'.") @@ -146,7 +135,7 @@ def plot( ) if colormap is None: colormap = plt.cm.viridis - elif self._da is not None: + else: # source is a DataArray (raster) if "proj4def" in m.crs: # it's a custom projection @@ -252,7 +241,7 @@ def plot( else: self.base_url = get_base_url(self.m.window_url) - if fit_bounds and self._da is not None: + if fit_bounds and not self.is_vector: asyncio.ensure_future(self.async_fit_bounds()) else: asyncio.ensure_future(self.async_wait_for_bounds()) @@ -302,7 +291,7 @@ def _get_selection(self, *args, **kwargs): def _start(self): self.m.add_control(self.spinner_control) - if self._da is not None: + if not self.is_vector: self._da, self.transform0_args = get_transform(self.transform0(self._da)) else: self.layer.name = self.measurement @@ -318,14 +307,16 @@ def _start(self): self.layer.path = self.url self.m.remove_control(self.spinner_control) - if self._da is not None: + if not self.is_vector: get_tiles = self._get_raster_tiles - else: + elif self._df is not None: get_tiles = self._get_vector_tiles + else: + raise RuntimeError("No DataArray or GeoDataFrame provided.") get_tiles() self.m.observe(get_tiles, names="pixel_bounds") if not self.dynamic: - if self._da is not None: + if not self.is_vector: self._show_colorbar(self._da_notransform) self.m.add_layer(self.layer) @@ -603,3 +594,22 @@ async def async_fit_bounds(self): if self.base_url is None: self.base_url = (await self.url_widget.get_url()).rstrip("/") self.map_ready = True + + +@xr.register_dataarray_accessor("leaflet") +class DataArrayLeaflet(Leaflet): + """A DataArraye extension for tiled map plotting, based on (ipy)leaflet.""" + + def __init__(self, da: xr.DataArray = None): + self._da = da + self._da_selected = None + self.is_vector = False + + +@pd.api.extensions.register_dataframe_accessor("leaflet") +class GeoDataFrameLeaflet(Leaflet): + """A GeoDataFrame extension for tiled map plotting, based on (ipy)leaflet.""" + + def __init__(self, df: gpd.GeoDataFrame = None): + self._df = df + self.is_vector = True