From d6da33a90baca3181873662f08bc4fc5804efd4c Mon Sep 17 00:00:00 2001 From: David Hodel <33126037+Davee02@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:52:09 +0100 Subject: [PATCH] use all available datasets for computing the encoding of the categorical data in the metrics evaluater (#300) --- src/synthcity/metrics/eval.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/synthcity/metrics/eval.py b/src/synthcity/metrics/eval.py index 416aa989..94411bb7 100644 --- a/src/synthcity/metrics/eval.py +++ b/src/synthcity/metrics/eval.py @@ -203,12 +203,16 @@ def evaluate( """ We need to encode the categorical data in the real and synthetic data. - To ensure each category in the two datasets are mapped to the same one hot vector, we merge X_syn into X_gt for computing the encoder. - TODO: Check whether the optional datasets also need to be taking into account when getting the encoder. + To ensure each category in the two datasets are mapped to the same one hot vector, we merge all avalable datasets for computing the encoder. """ - X_gt_df = X_gt.dataframe() - X_syn_df = X_syn.dataframe() - X_enc = create_from_info(pd.concat([X_gt_df, X_syn_df]), X_gt.info()) + all_df = pd.concat([X_gt.dataframe(), X_syn.dataframe()]) + if X_train: + all_df = pd.concat([all_df, X_train.dataframe()]) + if X_ref_syn: + all_df = pd.concat([all_df, X_ref_syn.dataframe()]) + if X_augmented: + all_df = pd.concat([all_df, X_augmented.dataframe()]) + X_enc = create_from_info(all_df, X_gt.info()) _, encoders = X_enc.encode() # now we encode the data