more dask support
jbogaard committed Jul 6, 2021
1 parent 95aac55 commit d50be78
Showing 6 changed files with 45 additions and 44 deletions.
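
Nearly every change below follows the same idiom: call .compute() only when the object is a dask collection (dask objects expose that method; pandas and numpy ones do not), and only at the points where an eager value is unavoidable, such as filtering rows, taking a max, or building a period range. A minimal standalone sketch of the guard; the helper name maybe_compute is illustrative and not part of the codebase:

    import pandas as pd

    def maybe_compute(x):
        # dask objects expose .compute(); pandas/numpy objects do not,
        # so this materializes lazy results and passes eager ones through.
        return x.compute() if hasattr(x, "compute") else x

    s = pd.Series([1, 2, 3])
    print(maybe_compute(s.max()))  # 3; the same call works if s were a dask Series
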
chainladder/core/base.py: 35 additions & 27 deletions
@@ -87,11 +87,12 @@ def __init__(
         origin = development if origin is None else origin
         origin_date = TriangleBase._to_datetime(data, origin, format=origin_format)
         self.origin_grain = TriangleBase._get_grain(origin_date)
-        origin_date = (
-            pd.PeriodIndex(origin_date, freq=self.origin_grain)
-            .to_timestamp()
-            .rename("origin")
-        )
+        date_offset = {
+            'M': (pd.offsets.MonthEnd(0), pd.offsets.MonthBegin(1)),
+            'Q': (pd.offsets.QuarterEnd(0), pd.offsets.MonthBegin(3)),
+            'Y': (pd.offsets.YearEnd(0), pd.offsets.MonthBegin(12))
+        }[self.origin_grain]
+        origin_date = origin_date + date_offset[0] - date_offset[1]
 
         # Initialize development and its grain
         m_cnt = {"Y": 12, "Q": 3, "M": 1}
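
The pd.PeriodIndex call removed above floors each date to the start of its origin period, but PeriodIndex has no dask equivalent; the offset pair achieves the same flooring with plain datetime arithmetic that dask can apply lazily. A quick standalone check of the equivalence at monthly grain (not repository code):

    import pandas as pd

    dates = pd.Series(pd.to_datetime(["2021-07-15", "2021-07-01", "2021-12-31"]))
    period_floor = pd.PeriodIndex(dates, freq="M").to_timestamp()             # old approach
    offset_floor = dates + pd.offsets.MonthEnd(0) - pd.offsets.MonthBegin(1)  # new approach
    assert (period_floor == pd.DatetimeIndex(offset_floor)).all()             # both yield month starts
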
@@ -112,31 +113,35 @@ def __init__(
             raise ValueError(
                 'Development lags could not be determined. This may be because development'
                 ' is expressed as an age where a date-like vector is required')
 
         # Summarize dataframe to the level specified in axes
-        key_gr = [origin_date, development_date] + [
+        data["__origin__"] = origin_date
+        data["__development__"] = development_date
+        key_gr = ["__origin__", "__development__"] + [
             data[item] for item in ([] if not index else index)
         ]
-        data_agg = data[columns].groupby(key_gr).sum().reset_index().fillna(0)
+        data_agg = data.groupby(key_gr)[columns].sum().reset_index().fillna(0)
+        data = data.drop(['__origin__', '__development__'], 1)
         if not index:
             index = ["Total"]
             data_agg[index[0]] = "Total"
 
         # Fill in any gaps in origin/development
         date_axes = self._get_date_axes(
-            data_agg["origin"], data_agg["development"]
+            data_agg["__origin__"], data_agg["__development__"]
         )  # cartesian product
         dev_lag = TriangleBase._development_lag(
-            data_agg["origin"], data_agg["development"]
+            data_agg["__origin__"], data_agg["__development__"]
         )
 
         # Grab unique index, origin, development
         dev_lag_unique = np.sort(
             TriangleBase._development_lag(
-                date_axes["origin"], date_axes["development"]
+                date_axes["__origin__"], date_axes["__development__"]
             ).unique()
         )
 
-        orig_unique = np.sort(date_axes["origin"].unique())
+        orig_unique = np.sort(date_axes["__origin__"].unique())
         kdims = data_agg[index].drop_duplicates().reset_index(drop=True).reset_index()
 
         # Map index, origin, development indices to data
@@ -145,19 +150,20 @@ def __init__(
             .values[None]
             .T
         )
-        orig_idx = set_idx(data_agg["origin"], orig_unique)
+        orig_idx = set_idx(data_agg["__origin__"], orig_unique)
         dev_idx = set_idx(dev_lag, dev_lag_unique)
         key_idx = (
             data_agg[index].merge(kdims, how="left", on=index)["index"].values[None].T
         )
 
         # origin <= development is required - truncate bad records if not true
-        valid = data_agg["origin"] <= data_agg["development"]
+        valid = data_agg["__origin__"] <= data_agg["__development__"]
         if sum(~valid) > 0:
             warnings.warn(
                 "Observations with development before "
                 + "origin start have been removed."
             )
+            valid = valid.compute() if hasattr(valid, 'compute') else valid
             data_agg, orig_idx = data_agg[valid], orig_idx[valid]
             dev_idx, key_idx = dev_idx[valid], key_idx[valid]
 
@@ -173,8 +179,7 @@ def __init__(
         coords = np.concatenate(
             (np.concatenate(tuple([key_idx] * len(columns)), 0), val_idx, coords), 1
         )
-        amts = data_agg[columns].unstack()
-        amts = amts.values.astype("float64")
+        amts = np.concatenate([data_agg[col].fillna(0).values for col in data_agg[columns]]).astype("float64")
         self.array_backend = "sparse"
         self.values = num_to_nan(
             sp(
@@ -193,7 +198,9 @@ def __init__(
         )
 
         # Set all axis values
-        self.valuation_date = data_agg["development"].max()
+        val_date = data_agg["__development__"].max()
+        val_date = val_date.compute() if hasattr(val_date, 'compute') else val_date
+        self.valuation_date = val_date
         self.kdims = kdims.drop("index", 1).values
         self.odims = orig_unique
         self.ddims = dev_lag_unique if has_dev else dev_lag[0:1].values
@@ -249,23 +256,23 @@ def complete_date_range(
         ):
             """ Determines origin/development combinations in full. Useful for
             when the triangle has holes in it. """
 
+            o_min = origin_date.min()
+            o_max = origin_date.max()
+            d_max = development_date.max()
+            c = lambda x : x.compute() if hasattr(x, 'compute') else x
             origin_unique = pd.period_range(
-                start=origin_date.min(),
-                end=max(origin_date.max(), development_date.max()),
-                freq=origin_grain,
+                start=c(o_min), end=max(c(o_max), c(d_max)), freq=origin_grain,
             ).to_timestamp()
             development_unique = pd.period_range(
-                start=origin_date.min(),
-                end=development_date.max(),
-                freq=development_grain,
+                start=c(o_min), end=c(d_max), freq=development_grain,
             ).to_timestamp(how="e")
             # Let's get rid of any development periods before origin periods
             cart_prod = TriangleBase._cartesian_product(
                 origin_unique, development_unique
             )
             cart_prod = cart_prod[cart_prod[:, 0] <= cart_prod[:, 1], :]
-            return pd.DataFrame(cart_prod, columns=["origin", "development"])
+            return pd.DataFrame(cart_prod, columns=["__origin__", "__development__"])
 
 
         cart_prod_o = complete_date_range(
             pd.Series(origin_date.min()),
@@ -280,14 +287,14 @@
             self.development_grain,
         )
         cart_prod_t = pd.DataFrame(
-            {"origin": origin_date, "development": development_date}
+            {"__origin__": origin_date, "__development__": development_date}
         )
         cart_prod = (
             cart_prod_o.append(cart_prod_d, sort=True)
             .append(cart_prod_t, sort=True)
             .drop_duplicates()
         )
-        cart_prod = cart_prod[cart_prod["development"] >= cart_prod["origin"]]
+        cart_prod = cart_prod[cart_prod["__development__"] >= cart_prod["__origin__"]]
         return cart_prod
 
     @property
@@ -349,7 +356,8 @@ def _development_lag(origin, development):
         year_diff = development.dt.year - origin.dt.year
         quarter_diff = development.dt.quarter - origin.dt.quarter
         month_diff = development.dt.month - origin.dt.month
-        if np.all(origin != development):
+        all = dp.all if hasattr(origin, 'compute') else np.all
+        if all(origin != development):
             development_grain = TriangleBase._get_grain(development)
         else:
             development_grain = "M"
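
Most of the base.py changes above share one motive: pandas accepts a free-standing Series as a groupby key, but dask.dataframe works best when the key is a real column, so the derived origin/development dates are written into temporary __origin__/__development__ columns, used for the aggregation, then dropped. A standalone sketch of the pattern on toy data (not repository code):

    import pandas as pd

    data = pd.DataFrame({
        "accident_date": pd.to_datetime(["2020-01-15", "2020-01-20", "2020-02-03"]),
        "paid": [100.0, 50.0, 75.0],
    })
    # Materialize the derived key as a column so grouping works by name,
    # which keeps the same call valid on a dask DataFrame.
    data["__origin__"] = data["accident_date"] + pd.offsets.MonthEnd(0) - pd.offsets.MonthBegin(1)
    agg = data.groupby(["__origin__"])["paid"].sum().reset_index()
    data = data.drop(columns=["__origin__"])
    print(agg)  # paid: 150.0 for the January origin, 75.0 for February
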
chainladder/core/dunders.py: 2 additions & 2 deletions
@@ -27,8 +27,8 @@ def _validate_arithmetic(self, other):
             other = other.values
         else:
             if isinstance(other, np.ndarray) and self.array_backend != 'numpy':
-                other = self.get_array_module().array(other)
                 obj = self.copy()
+                other = obj.get_array_module().array(other)
             elif isinstance(other, sp) and self.array_backend != 'sparse':
                 obj = self.set_backend('sparse')
             else:
@@ -91,7 +91,7 @@ def _prep_columns(self, x, y):
         elif len(y.columns) == 1 and len(x.columns) > 1:
             y.vdims = x.vdims
         elif len(y.columns) == 1 and len(x.columns) == 1 and x.columns != y.columns:
-            y.vdims = x.vdims = pd.RangeIndex(start=0, stop=1, step=1)
+            y.vdims = x.vdims
         elif x.shape[1] == y.shape[1] and np.all(x.columns == y.columns):
             pass
         else:
chainladder/core/io.py: 4 additions & 4 deletions
@@ -20,9 +20,8 @@ def to_pickle(self, path, protocol=None):
             The pickle protocol to use.
         """
-        out = self.copy()
-        out.virtual_columns = dill.dumps(out.virtual_columns)
-        joblib.dump(out, filename=path, protocol=protocol)
+        with open(path, "wb") as pkl:
+            dill.dump(self, pkl)
 
     def to_json(self):
         """ Serializes triangle object to json format
@@ -66,7 +65,8 @@ def to_pickle(self, path, protocol=None):
         protocol :
             The pickle protocol to use.
         """
-        joblib.dump(self, filename=path, protocol=protocol)
+        with open(path, "wb") as pkl:
+            dill.dump(self, pkl)
 
     def to_json(self):
         """ Serializes triangle object to json format
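
Both to_pickle variants now write through dill instead of joblib. A plausible reading of the removed lines: the old path had to dill.dumps the virtual_columns attribute (which can hold lambdas) before handing the object to joblib, while dill serializes lambdas natively, so the whole object can go through a single call. A self-contained sketch of the round trip:

    import dill

    obj = {"virtual": lambda tri: tri * 2}  # stand-in for an object carrying lambdas
    with open("obj.pkl", "wb") as pkl:      # mirrors the new to_pickle body
        dill.dump(obj, pkl)
    with open("obj.pkl", "rb") as pkl:      # mirrors the new read_pickle body
        restored = dill.load(pkl)
    assert restored["virtual"](3) == 6      # the lambda survives the round trip
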
chainladder/methods/base.py: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ def _set_ult_attr(self, ultimate):
         ultimate.values[~xp.isfinite(ultimate.values)] = xp.nan
         ultimate.ddims = pd.DatetimeIndex([ULT_VAL])
         ultimate.virtual_columns.columns = {}
+        ultimate.is_cumulative = True
         ultimate._set_slicers()
         ultimate.valuation_date = ultimate.valuation.max()
         ultimate._drop_subtriangles()
chainladder/tails/curve.py: 1 addition & 4 deletions
@@ -88,10 +88,7 @@ def fit(self, X, y=None, sample_weight=None):
         """
         from chainladder.utils.utility_functions import num_to_nan
 
-        if X.array_backend == "sparse":
-            X = X.set_backend("numpy")
-        else:
-            X = X.copy()
+        X = X.copy()
         xp = X.get_array_module()
         if type(self.fit_period) == slice:
             warnings.warn(
chainladder/utils/utility_functions.py: 2 additions & 7 deletions
@@ -89,13 +89,8 @@ def load_dataset(key, *args, **kwargs):
 
 
 def read_pickle(path):
-    out = joblib.load(path)
-    try:
-        out.virtual_columns = dill.loads(out.virtual_columns)
-    except:
-        from chainladder.core.slice import VirtualColumns
-        out.virtual_columns = VirtualColumns(out)
-    return out
+    with open(path, "rb") as pkl:
+        return dill.load(pkl)
 
 
 def read_json(json_str, array_backend=None):
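
End to end, the pickle API keeps its public shape. Assuming the load_dataset helper visible in the hunk above, a triangle should round-trip like this (usage sketch, untested against this exact revision):

    import chainladder as cl

    tri = cl.load_dataset("raa")      # sample loader shown in the diff above
    tri.to_pickle("raa.pkl")          # now writes via dill
    tri2 = cl.read_pickle("raa.pkl")  # now reads via dill
    print(tri2.shape == tri.shape)    # the round trip preserves the triangle
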
