diff --git a/src/synthcity/metrics/eval.py b/src/synthcity/metrics/eval.py index 416aa989..94411bb7 100644 --- a/src/synthcity/metrics/eval.py +++ b/src/synthcity/metrics/eval.py @@ -203,12 +203,16 @@ def evaluate( """ We need to encode the categorical data in the real and synthetic data. - To ensure each category in the two datasets are mapped to the same one hot vector, we merge X_syn into X_gt for computing the encoder. - TODO: Check whether the optional datasets also need to be taking into account when getting the encoder. + To ensure each category in the datasets is mapped to the same one hot vector, we merge all available datasets for computing the encoder. """ - X_gt_df = X_gt.dataframe() - X_syn_df = X_syn.dataframe() - X_enc = create_from_info(pd.concat([X_gt_df, X_syn_df]), X_gt.info()) + all_df = pd.concat([X_gt.dataframe(), X_syn.dataframe()]) + if X_train: + all_df = pd.concat([all_df, X_train.dataframe()]) + if X_ref_syn: + all_df = pd.concat([all_df, X_ref_syn.dataframe()]) + if X_augmented: + all_df = pd.concat([all_df, X_augmented.dataframe()]) + X_enc = create_from_info(all_df, X_gt.info()) _, encoders = X_enc.encode() # now we encode the data