diff --git a/chainladder/core/base.py b/chainladder/core/base.py index be85795b..5b25750b 100644 --- a/chainladder/core/base.py +++ b/chainladder/core/base.py @@ -274,12 +274,18 @@ def _get_grain(dates, trailing=False, kind="origin"): diffs = np.diff(np.sort(months)) if len(dates.unique()) == 1: grain = ( - "Y" if float(".".join(pd.__version__.split(".")[:-1])) >= 2.2 else "A" + "A" + if version.Version(pd.__version__) >= version.Version("2.2.0") + else "Y" ) + elif len(months) == 1: grain = ( - "Y" if float(".".join(pd.__version__.split(".")[:-1])) >= 2.2 else "A" + "A" + if version.Version(pd.__version__) >= version.Version("2.2.0") + else "Y" ) + elif np.all(diffs == 6): grain = "2Q" elif np.all(diffs == 3): diff --git a/chainladder/core/correlation.py b/chainladder/core/correlation.py index a46e9279..2de06522 100644 --- a/chainladder/core/correlation.py +++ b/chainladder/core/correlation.py @@ -189,19 +189,14 @@ def pZlower(z: int, n: int, p: float = 0.5) -> float: lr = triangle.link_ratio # Rank link ratios for each column - m1 = xp.apply_along_axis( - func1d=rankdata, - axis=2, - arr=lr.values, - ) * (lr.values * 0 + 1) - - med = xp.nanmedian( - a=m1, - axis=2, - keepdims=True, + m1 = xp.apply_along_axis(func1d=rankdata, axis=2, arr=lr.values) * ( + lr.values * 0 + 1 ) # print("med:\n", med) + med = xp.nanmedian(a=m1, axis=2, keepdims=True) + # print("med:\n", med) + m1large = (xp.nan_to_num(m1) > med) + (lr.values * 0) m1small = (xp.nan_to_num(m1) < med) + (lr.values * 0) m2large = triangle.link_ratio diff --git a/chainladder/core/triangle.py b/chainladder/core/triangle.py index 86fa5a03..b9806c7b 100644 --- a/chainladder/core/triangle.py +++ b/chainladder/core/triangle.py @@ -6,6 +6,7 @@ import numpy as np import copy import warnings +from packaging import version from chainladder.core.base import TriangleBase from chainladder.utils.sparse import sp from chainladder.core.slice import VirtualColumns @@ -331,9 +332,9 @@ def origin(self): else: freq = { "Y": ( - "Y" - if float(".".join(pd.__version__.split(".")[:-1])) >= 2.2 - else "A" + "A" + if version.Version(pd.__version__) >= version.Version("2.2.0") + else "Y" ), "S": "2Q", "H": "2Q", @@ -345,7 +346,7 @@ def origin(self): def origin(self, value): self._len_check(self.origin, value) freq = { - "Y": "Y" if float(".".join(pd.__version__.split(".")[:-1])) >= 2.2 else "A", + "Y": "Y" if float(".".join(pd.__version__.split(".")[:-1])) < 2.2 else "A", "S": "2Q", }.get(self.origin_grain, self.origin_grain) freq = freq if freq == "M" else freq + "-" + self.origin_close