Merge remote-tracking branch 'origin/dev' into py312
dbrakenhoff committed Jun 26, 2024
2 parents ffb6dca + 508b1f9 commit 2c17b75
Showing 11 changed files with 163 additions and 137 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -93,5 +93,5 @@ jobs:
if: ${{ github.repository == 'pastas/pastastore' && success() }}
uses: codacy/codacy-coverage-reporter-action@master
with:
-project-token: ${{ secrets.CODACY_PROJECT_TOKEN }}
+project-token: ${{ secrets.CODACY_API_TOKEN }}
coverage-reports: coverage.xml
63 changes: 56 additions & 7 deletions pastastore/base.py
@@ -1,6 +1,8 @@
import functools
import json
import warnings

# import weakref
from abc import ABC, abstractmethod, abstractproperty
from collections.abc import Iterable
from itertools import chain
@@ -20,6 +22,34 @@
warnings.showwarning = _custom_warning


# def weak_lru(maxsize=128, typed=False):
# """LRU Cache decorator that keeps a weak reference to 'self'.

# From https://stackoverflow.com/a/68052994/10596229.

# Parameters
# ----------
# maxsize : int, optional
# maximum size of cache, by default 128
# typed : bool, optional
# whether to differentiate between types, by default False

# """

# def wrapper(func):
# @functools.lru_cache(maxsize, typed)
# def _func(_self, *args, **kwargs):
# return func(_self(), *args, **kwargs)

# @functools.wraps(func)
# def inner(self, *args, **kwargs):
# return _func(weakref.ref(self), *args, **kwargs)

# return inner

# return wrapper
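
For reference, a hedged sketch of how the disabled decorator above would be used if re-enabled (the `Connector` class and `get_series` method here are illustrative, not part of this commit). The point of the weak reference: `functools.lru_cache` on a bound method holds a strong reference to `self`, so decorated instances are never garbage-collected; caching on a `weakref.ref` instead avoids that leak.

```python
import functools
import weakref


def weak_lru(maxsize=128, typed=False):
    """LRU cache decorator that holds only a weak reference to 'self'."""

    def wrapper(func):
        @functools.lru_cache(maxsize, typed)
        def _func(_self, *args, **kwargs):
            return func(_self(), *args, **kwargs)

        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            return _func(weakref.ref(self), *args, **kwargs)

        return inner

    return wrapper


class Connector:  # illustrative stand-in for a connector class
    @weak_lru(maxsize=32)
    def get_series(self, name: str) -> str:
        print(f"cache miss for {name}")  # printed once per (instance, name)
        return name.upper()
```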


class BaseConnector(ABC):
"""Base Connector class.
@@ -312,6 +342,9 @@ def _update_series(
self._validate_input_series(series)
series = self._set_series_name(series, name)
stored = self._get_series(libname, name, progressbar=False)
if self.conn_type == "pas" and (type(series) != type(stored)):
if isinstance(series, pd.DataFrame):
stored = stored.to_frame()
# get union of index
idx_union = stored.index.union(series.index)
# update series with new values
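
The hunk is truncated here, but the comments show the pattern: take the union of the stored and incoming indices, reindex, and let the incoming values win where the two overlap. A minimal standalone sketch (variable names illustrative):

```python
import pandas as pd

stored = pd.Series([1.0, 2.0], index=pd.to_datetime(["2024-01-01", "2024-01-02"]))
series = pd.Series([2.5, 3.0], index=pd.to_datetime(["2024-01-02", "2024-01-03"]))

idx_union = stored.index.union(series.index)
updated = stored.reindex(idx_union)
updated.update(series)  # overlapping timestamps take the new values
# result: 2024-01-01 -> 1.0, 2024-01-02 -> 2.5, 2024-01-03 -> 3.0
```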
@@ -659,6 +692,16 @@ def del_models(self, names: Union[list, str]) -> None:
self._del_oseries_model_link(oname, n)
self._clear_cache("_modelnames_cache")

def del_model(self, names: Union[list, str]) -> None:
"""Delete model(s) from the database.

Parameters
----------
names : str or list of str
name(s) of the model to delete
"""
self.del_models(names=names)

def del_oseries(self, names: Union[list, str], remove_models: bool = False):
"""Delete oseries from the database.
@@ -1247,11 +1290,13 @@ def _meta_list_to_frame(metalist: list, names: list):
meta = pd.DataFrame(metalist)
elif len(metalist) == 0:
meta = pd.DataFrame()

meta.index = names
meta.index.name = "name"
return meta
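
For clarity, a small sketch of what this helper produces for a list of metadata dicts (values illustrative):

```python
import pandas as pd

metalist = [{"x": 100.0, "y": 350.0}, {"x": 200.0, "y": 400.0}]
names = ["well_a", "well_b"]

meta = pd.DataFrame(metalist)
meta.index = names
meta.index.name = "name"
#             x      y
# name
# well_a  100.0  350.0
# well_b  200.0  400.0
```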

def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
"""Internal method to parse dictionary describing pastas models.
"""Parse dictionary describing pastas models (internal method).
Parameters
----------
@@ -1276,7 +1321,7 @@ def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
if name not in self.oseries.index:
msg = "oseries '{}' not present in library".format(name)
raise LookupError(msg)
mdict["oseries"]["series"] = self.get_oseries(name)
mdict["oseries"]["series"] = self.get_oseries(name).squeeze()
# update tmin/tmax from time series
if update_ts_settings:
mdict["oseries"]["settings"]["tmin"] = mdict["oseries"]["series"].index[
@@ -1296,7 +1341,7 @@ def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
if "series" not in stress:
name = str(stress["name"])
if name in self.stresses.index:
stress["series"] = self.get_stresses(name)
stress["series"] = self.get_stresses(name).squeeze()
# update tmin/tmax from time series
if update_ts_settings:
stress["settings"]["tmin"] = stress["series"].index[
@@ -1311,7 +1356,7 @@ def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
if "series" not in stress:
name = str(stress["name"])
if name in self.stresses.index:
stress["series"] = self.get_stresses(name)
stress["series"] = self.get_stresses(name).squeeze()
# update tmin/tmax from time series
if update_ts_settings:
stress["settings"]["tmin"] = stress["series"].index[
@@ -1327,7 +1372,7 @@ def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
if "series" not in stress:
name = str(stress["name"])
if name in self.stresses.index:
stress["series"] = self.get_stresses(name)
stress["series"] = self.get_stresses(name).squeeze()
# update tmin/tmax from time series
if update_ts_settings:
stress["settings"]["tmin"] = stress["series"].index[0]
@@ -1717,23 +1762,27 @@ def _models_to_archive(self, archive, names=None, progressbar=True):
archive.writestr(f"models/{n}.pas", jsondict)

@staticmethod
-def _series_from_json(fjson: str):
+def _series_from_json(fjson: str, squeeze: bool = True):
"""Load time series from JSON.

Parameters
----------
fjson : str
path to file
squeeze : bool, optional
squeeze time series object to obtain pandas Series, by default True

Returns
-------
s : pd.Series or pd.DataFrame
time series; a single-column DataFrame is squeezed to a Series when squeeze=True
"""
-s = pd.read_json(fjson, orient="columns", precise_float=True)
+s = pd.read_json(fjson, orient="columns", precise_float=True, dtype=False)
if not isinstance(s.index, pd.DatetimeIndex):
s.index = pd.to_datetime(s.index, unit="ms")
s = s.sort_index() # needed for some reason ...
if squeeze:
return s.squeeze()
return s
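
A hedged round-trip sketch of the loader above (the file name is illustrative). The added `dtype=False` disables dtype inference so stored values are read back unchanged, and the `unit="ms"` fallback covers indices that were serialized as epoch milliseconds:

```python
import pandas as pd

df = pd.DataFrame(
    {"head": [1.1, 1.2]},
    index=pd.to_datetime(["2024-01-01", "2024-01-02"]),
)
df.to_json("oseries_demo.pas", orient="columns")

s = pd.read_json("oseries_demo.pas", orient="columns", precise_float=True, dtype=False)
if not isinstance(s.index, pd.DatetimeIndex):
    s.index = pd.to_datetime(s.index, unit="ms")
s = s.sort_index().squeeze()  # pd.Series, matching the squeeze=True default
```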

@staticmethod
12 changes: 12 additions & 0 deletions pastastore/connectors.py
@@ -29,6 +29,12 @@ def __init__(self, name: str, connstr: str):
connstr : str
connection string (e.g. 'mongodb://localhost:27017/')
"""
warnings.warn(
"ArcticConnector is deprecated. Please use a different "
"connector, e.g. `pst.ArcticDBConnector`.",
DeprecationWarning,
stacklevel=1,
)
try:
import arctic
except ModuleNotFoundError as e:
@@ -392,6 +398,12 @@ def __init__(self, name: str, path: str):
path : str
path to the pystore directory
"""
warnings.warn(
"PystoreConnector is deprecated. Please use a different "
"connector, e.g. `pst.PasConnector`.",
DeprecationWarning,
stacklevel=1,
)
try:
import pystore
except ModuleNotFoundError as e:
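
Both constructors now warn before attempting the optional import. A hedged sketch of handling the warning downstream during a migration (connector arguments are illustrative, and the `ArcticDBConnector` URI format is an assumption here):

```python
import warnings

import pastastore as pst

# surface the DeprecationWarning (hidden by default outside __main__)
with warnings.catch_warnings():
    warnings.simplefilter("always", DeprecationWarning)
    conn = pst.ArcticConnector("mydb", "mongodb://localhost:27017/")

# suggested replacement going forward
conn = pst.ArcticDBConnector("mydb", "lmdb://./arcticdb")
```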
35 changes: 28 additions & 7 deletions pastastore/store.py
@@ -1,7 +1,7 @@
import json
import os
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
@@ -14,6 +14,7 @@
from pastastore.connectors import DictConnector
from pastastore.plotting import Maps, Plots
from pastastore.util import _custom_warning
from pastastore.version import PASTAS_GEQ_150
from pastastore.yaml_interface import PastastoreYAML

FrameorSeriesUnion = Union[pd.DataFrame, pd.Series]
@@ -384,7 +385,12 @@ def get_signatures(

return signatures_df

-def get_tmin_tmax(self, libname, names=None, progressbar=False):
+def get_tmin_tmax(
+    self,
+    libname: Literal["oseries", "stresses", "models"],
+    names: Union[str, List[str], None] = None,
+    progressbar: bool = False,
+):
"""Get tmin and tmax for time series.

Parameters
@@ -410,12 +416,22 @@ def get_tmin_tmax(self, libname, names=None, progressbar=False):
)
desc = f"Get tmin/tmax {libname}"
for n in tqdm(names, desc=desc) if progressbar else names:
if libname == "oseries":
s = self.conn.get_oseries(n)
if libname == "models":
mld = self.conn.get_models(
n,
return_dict=True,
)
tmintmax.loc[n, "tmin"] = mld["settings"]["tmin"]
tmintmax.loc[n, "tmax"] = mld["settings"]["tmax"]
else:
s = self.conn.get_stresses(n)
tmintmax.loc[n, "tmin"] = s.first_valid_index()
tmintmax.loc[n, "tmax"] = s.last_valid_index()
s = (
self.conn.get_oseries(n)
if libname == "oseries"
else self.conn.get_stresses(n)
)
tmintmax.loc[n, "tmin"] = s.first_valid_index()
tmintmax.loc[n, "tmax"] = s.last_valid_index()

return tmintmax
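
A hedged usage sketch (assuming an existing `pstore` instance): with the reworked branch above, model ranges are read straight from each stored model's settings instead of loading the underlying time series.

```python
# tmin/tmax per oseries, derived from the series themselves
print(pstore.get_tmin_tmax("oseries"))
# tmin/tmax per model, read from the stored model settings
print(pstore.get_tmin_tmax("models", progressbar=True))
```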

def get_extent(self, libname, names=None, buffer=0.0):
@@ -558,6 +574,7 @@ def create_model(
name: str,
modelname: str = None,
add_recharge: bool = True,
add_ar_noisemodel: bool = False,
recharge_name: str = "recharge",
) -> ps.Model:
"""Create a pastas Model.
@@ -572,6 +589,8 @@
add recharge to the model by looking for the closest
precipitation and evaporation time series in the stresses
library, by default True
add_ar_noisemodel : bool, optional
add AR(1) noise model to the model, by default False
recharge_name : str
name of the RechargeModel
@@ -598,6 +617,8 @@
ml = ps.Model(ts, name=modelname, metadata=meta)
if add_recharge:
self.add_recharge(ml, recharge_name=recharge_name)
if add_ar_noisemodel and PASTAS_GEQ_150:
ml.add_noisemodel(ps.ArNoiseModel())
return ml
else:
raise ValueError("Empty time series!")
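
A hedged usage sketch of the new flag (assuming an existing `pstore` instance; the oseries name is illustrative). The AR(1) noise model is only attached when pastas >= 1.5.0 is installed:

```python
ml = pstore.create_model(
    "oseries_01",
    add_recharge=True,
    add_ar_noisemodel=True,  # silently skipped on pastas < 1.5.0
)
ml.solve()
pstore.conn.add_model(ml)
```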
7 changes: 6 additions & 1 deletion pastastore/util.py
@@ -552,6 +552,7 @@ def frontiers_checks(
check4_gain: bool = True,
check5_parambounds: bool = False,
csv_dir: Optional[str] = None,
progressbar: bool = False,
) -> pd.DataFrame: # pragma: no cover
"""Check models in a PastaStore to see if they pass reliability criteria.
@@ -597,6 +598,8 @@
csv_dir : string, optional
directory to store CSV file with overview of checks for every
model, by default None which will not store results
progressbar : bool, optional
show progressbar, by default False

Returns
-------
@@ -629,7 +632,9 @@
else:
models = pstore.model_names

-for mlnam in tqdm(models, desc="Running model diagnostics"):
+for mlnam in (
+    tqdm(models, desc="Running model diagnostics") if progressbar else models
+):
ml = pstore.get_models(mlnam)

if ml.parameters["optimal"].hasnans:
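
A hedged call sketch for the new flag (assuming an existing `pstore` instance and that the store is the first positional argument):

```python
from pastastore.util import frontiers_checks

checks = frontiers_checks(pstore, progressbar=True)  # runs silently by default
```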
33 changes: 32 additions & 1 deletion pastastore/version.py
@@ -1,7 +1,38 @@
# ruff: noqa: D100
from importlib import import_module, metadata
from platform import python_version

import pastas as ps
from packaging.version import parse as parse_version

PASTAS_VERSION = parse_version(ps.__version__)
PASTAS_LEQ_022 = PASTAS_VERSION <= parse_version("0.22.0")
PASTAS_GEQ_150 = PASTAS_VERSION >= parse_version("1.5.0")

__version__ = "1.5.0"


def show_versions(optional=False) -> None:
"""Print the version of dependencies.

Parameters
----------
optional : bool, optional
Print the version of optional dependencies, by default False
"""
msg = (
f"Python version : {python_version()}\n"
f"Pandas version : {metadata.version('pandas')}\n"
f"Matplotlib version : {metadata.version('matplotlib')}\n"
f"Pastas version : {metadata.version('pastas')}\n"
f"PyYAML version : {metadata.version('pyyaml')}\n"
)
if optional:
msg += "\nArcticDB version : "
try:
import_module("arcticdb")
msg += f"{metadata.version('arctidb')}"
except ImportError:
msg += "Not Installed"

__version__ = "1.4.0"
print(msg)
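
A hedged sketch of how the new version flag and helper are meant to be used together:

```python
import pastas as ps

from pastastore.version import PASTAS_GEQ_150, show_versions

show_versions(optional=True)  # also reports the ArcticDB version if installed

if PASTAS_GEQ_150:
    noisemodel = ps.ArNoiseModel()  # AR(1) noise model, pastas >= 1.5.0 only
```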