From d36f37fcd36dce0c57bdd346d2359cb85ccbc798 Mon Sep 17 00:00:00 2001 From: Rob <62107751+robsdavis@users.noreply.github.com> Date: Tue, 12 Sep 2023 19:05:01 +0100 Subject: [PATCH] bugfix in ts dataloader (#235) * remove concat from time series dataloader unpack * pin xgboost --- setup.cfg | 2 +- src/synthcity/plugins/core/dataloader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index a9c6fb5b..974f8198 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = pydantic<2.0 cloudpickle scipy - xgboost + xgboost<2.0.0 geomloss pgmpy redis diff --git a/src/synthcity/plugins/core/dataloader.py b/src/synthcity/plugins/core/dataloader.py index 40bee0a8..c01cede7 100644 --- a/src/synthcity/plugins/core/dataloader.py +++ b/src/synthcity/plugins/core/dataloader.py @@ -932,7 +932,7 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any: longest_observation_seq = max([len(seq) for seq in temporal_data]) return ( np.asarray(static_data), - np.asarray(pd.concat(temporal_data)), + np.asarray(temporal_data), # masked array to handle variable length sequences ma.vstack( [