diff --git a/gliderpy/__init__.py b/gliderpy/__init__.py index 37971ac..8d2cbf0 100644 --- a/gliderpy/__init__.py +++ b/gliderpy/__init__.py @@ -4,3 +4,11 @@ from ._version import __version__ except ImportError: __version__ = "unknown" + +from .fetchers import GliderDataFetcher +from .plotting import plot_transect + +__all__ = [ + "GliderDataFetcher", + "plot_transect", +] diff --git a/gliderpy/fetchers.py b/gliderpy/fetchers.py index 2c5326d..009d915 100644 --- a/gliderpy/fetchers.py +++ b/gliderpy/fetchers.py @@ -16,6 +16,7 @@ ) OptionalBool = bool | None +OptionalDF = pd.DataFrame | None OptionalDict = dict | None OptionalList = list[str] | tuple[str] | None OptionalStr = str | None @@ -80,7 +81,7 @@ def __init__( ) self.fetcher.variables = server_vars[server] self.fetcher.dataset_id: OptionalStr = None - self.datasets: OptionalBool = None + self.datasets: OptionalDF = None def to_pandas(self: "GliderDataFetcher") -> pd.DataFrame: """Return data from the server as a pandas dataframe. @@ -97,10 +98,7 @@ def to_pandas(self: "GliderDataFetcher") -> pd.DataFrame: self.fetcher.dataset_id = None return glider_df else: - msg = ( - f"Must provide a {self.fetcher.dataset_id} or " - "`query` terms to download data." - ) + msg = "Must provide a dataset_id or query terms to download data." raise ValueError(msg) # Standardize variable names for the single dataset_id. @@ -145,7 +143,7 @@ def query( # noqa: PLR0913 "longitude>=": min_lon, "longitude<=": max_lon, } - if not self.datasets: + if self.datasets is None: url = self.fetcher.get_search_url( search_for="glider", response="csv", diff --git a/gliderpy/plotting.py b/gliderpy/plotting.py new file mode 100644 index 0000000..d1e5d33 --- /dev/null +++ b/gliderpy/plotting.py @@ -0,0 +1,57 @@ +"""Some convenience functions to help visualize glider data.""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +try: + import matplotlib.dates as mdates + import matplotlib.pyplot as plt +except ModuleNotFoundError: + warnings.warn( + "gliderpy requires matplotlib and cartopy for plotting.", + stacklevel=1, + ) + raise + + +if TYPE_CHECKING: + import pandas as pd + +from pandas_flavor import register_dataframe_method + + +@register_dataframe_method +def plot_transect( + df: pd.DataFrame, + var: str, + **kw: dict, +) -> tuple(plt.Figure, plt.Axes): + """Make a scatter plot of depth vs time coloured by a user defined + variable. + + :param var: variable to colour the scatter plot + :return: figure, axes + """ + cmap = kw.get("cmap", None) + + fig, ax = plt.subplots(figsize=(17, 2)) + cs = ax.scatter( + df.index, + df["pressure"], + s=15, + c=df[var], + marker="o", + edgecolor="none", + cmap=cmap, + ) + + ax.invert_yaxis() + xfmt = mdates.DateFormatter("%H:%Mh\n%d-%b") + ax.xaxis.set_major_formatter(xfmt) + + cbar = fig.colorbar(cs, orientation="vertical", extend="both") + cbar.ax.set_ylabel(var) + ax.set_ylabel("pressure") + return fig, ax