Skip to content

Commit

Permalink
resolves #59 Add cl.concat functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
jbogaardt committed Mar 7, 2020
1 parent 0c66fbf commit c9edf4e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
3 changes: 1 addition & 2 deletions chainladder/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@
DataFrame, Series, Row, Column, Tabs, CSpacer, RSpacer, Title, Image,
VSpacer, HSpacer) # noqa (API import)
from chainladder.utils.utility_functions import ( # noqa (API import)
load_dataset, parallelogram_olf, read_pickle, read_json)
load_dataset, parallelogram_olf, read_pickle, read_json, concat)
from chainladder.utils.cupy import cp

5 changes: 5 additions & 0 deletions chainladder/utils/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,8 @@ def test_pipeline_json_io():
for item in pipe.get_params()['steps']} == \
{item[0]: item[1].get_params()
for item in pipe2.get_params()['steps']}

def test_concat():
tri = cl.load_dataset('clrd').groupby('LOB').sum()
assert cl.concat([tri.loc['wkcomp'], tri.loc['comauto']], axis=0) == \
tri.loc[['wkcomp', 'comauto']]
35 changes: 35 additions & 0 deletions chainladder/utils/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import joblib
import json
import os
import copy
from chainladder.core.triangle import Triangle
from chainladder.workflow import Pipeline

Expand Down Expand Up @@ -128,3 +129,37 @@ def parallelogram_olf(values, date, start_date=None, end_date=None,
y.columns = ['Origin', 'OLF']
y['Origin'] = y['Origin'].astype(str)
return y.set_index('Origin')

def concat(objs, axis):
""" Concatenate Triangle objects along a particular axis.
Parameters
----------
objs : list or tuple
A list or tuple of Triangle objects to concat. All non-concat axes must
be identical and all elements of the concat axes must be unique.
axis : string or int
The axis along which to concatenate.
Returns
-------
Updated triangle
"""
xp = cp.get_array_module(objs[0].values)
axis = objs[0]._get_axis(axis)
mapper = {0: 'kdims', 1: 'vdims', 2: 'odims', 3: 'ddims'}
for k, v in mapper.items():
if k != axis: # All non-concat axes must be identical
assert xp.all(xp.array([getattr(obj, mapper[k]) for obj in objs]) ==
getattr(objs[0], mapper[k]))
else: # All elements of concat axis must be unique
new_axis = xp.concatenate([getattr(obj, mapper[axis]) for obj in objs])
if axis == 0:
assert len(pd.DataFrame(new_axis).drop_duplicates()) == len(new_axis)
else:
assert len(new_axis) == len(set(new_axis))
out = copy.deepcopy(objs[0])
out.values = xp.concatenate([obj.values for obj in objs], axis=axis)
setattr(out, mapper[axis], new_axis)
out._set_slicers()
return out

0 comments on commit c9edf4e

Please sign in to comment.