From d36f37fcd36dce0c57bdd346d2359cb85ccbc798 Mon Sep 17 00:00:00 2001 From: Rob <62107751+robsdavis@users.noreply.github.com> Date: Tue, 12 Sep 2023 19:05:01 +0100 Subject: [PATCH 1/2] bugfix in ts dataloader (#235) * remove concat from time series dataloader unpack * pin xgboost --- setup.cfg | 2 +- src/synthcity/plugins/core/dataloader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index a9c6fb5b..974f8198 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = pydantic<2.0 cloudpickle scipy - xgboost + xgboost<2.0.0 geomloss pgmpy redis diff --git a/src/synthcity/plugins/core/dataloader.py b/src/synthcity/plugins/core/dataloader.py index 40bee0a8..c01cede7 100644 --- a/src/synthcity/plugins/core/dataloader.py +++ b/src/synthcity/plugins/core/dataloader.py @@ -932,7 +932,7 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any: longest_observation_seq = max([len(seq) for seq in temporal_data]) return ( np.asarray(static_data), - np.asarray(pd.concat(temporal_data)), + np.asarray(temporal_data), # masked array to handle variable length sequences ma.vstack( [ From f7549cf9db2eba56fe3f73972397598e1e9bc885 Mon Sep 17 00:00:00 2001 From: Rob <62107751+robsdavis@users.noreply.github.com> Date: Tue, 12 Sep 2023 23:25:28 +0100 Subject: [PATCH 2/2] bump version.py --- src/synthcity/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/synthcity/version.py b/src/synthcity/version.py index 07aec94b..d89f15a8 100644 --- a/src/synthcity/version.py +++ b/src/synthcity/version.py @@ -1,4 +1,4 @@ -__version__ = "0.2.8" +__version__ = "0.2.9" MAJOR_VERSION = ".".join(__version__.split(".")[:-1]) PATCH_VERSION = __version__.split(".")[-1]