diff --git a/src/synthcity/metrics/eval.py b/src/synthcity/metrics/eval.py
index 9f57b7af..c6d0fbd3 100644
--- a/src/synthcity/metrics/eval.py
+++ b/src/synthcity/metrics/eval.py
@@ -200,15 +200,26 @@ def evaluate(
     if metrics is None:
         metrics = Metrics.list()
 
-    X_gt, _ = X_gt.encode()
-    X_syn, _ = X_syn.encode()
+    """
+    We need to encode the categorical data in the real and synthetic data.
+    To ensure each category in the two datasets is mapped to the same one-hot vector, we merge X_syn into X_gt when computing the encoder.
+    TODO: Check whether the optional datasets also need to be taken into account when getting the encoder.
+    """
+    X_gt_df = X_gt.dataframe()
+    X_syn_df = X_syn.dataframe()
+    X_enc = create_from_info(pd.concat([X_gt_df, X_syn_df]), X_gt.info())
+    _, encoders = X_enc.encode()
+
+    # now encode each dataset with the shared encoders
+    X_gt, _ = X_gt.encode(encoders)
+    X_syn, _ = X_syn.encode(encoders)
 
     if X_train:
-        X_train, _ = X_train.encode()
+        X_train, _ = X_train.encode(encoders)
     if X_ref_syn:
-        X_ref_syn, _ = X_ref_syn.encode()
+        X_ref_syn, _ = X_ref_syn.encode(encoders)
     if X_augmented:
-        X_augmented, _ = X_augmented.encode()
+        X_augmented, _ = X_augmented.encode(encoders)
 
     scores = ScoreEvaluator()
 
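To illustrate why the diff fits the encoder on the concatenated data rather than on each dataset separately, here is a minimal, self-contained sketch using plain pandas (not the synthcity DataLoader API; the `real`/`synthetic` frames and the `color` column are made up for illustration). Categories that appear in only one dataset would otherwise map to different, incompatible one-hot columns.

```python
import pandas as pd

real = pd.DataFrame({"color": ["red", "blue", "red"]})
synthetic = pd.DataFrame({"color": ["blue", "green"]})  # "green" never occurs in `real`

# Encoding each dataset independently yields incompatible column sets:
print(pd.get_dummies(real).columns.tolist())       # ['color_blue', 'color_red']
print(pd.get_dummies(synthetic).columns.tolist())  # ['color_blue', 'color_green']

# Deriving the category set from the concatenation fixes the
# category -> column mapping, which is then applied to each dataset
# individually via a shared categorical dtype:
categories = pd.concat([real, synthetic])["color"].unique()
dtype = pd.CategoricalDtype(categories=categories)
real_enc = pd.get_dummies(real["color"].astype(dtype), prefix="color")
syn_enc = pd.get_dummies(synthetic["color"].astype(dtype), prefix="color")

# Both encoded frames now share identical columns, so downstream
# metrics compare like with like.
assert list(real_enc.columns) == list(syn_enc.columns)
```

The patch applies the same idea at the DataLoader level: `create_from_info` builds a combined loader from `pd.concat([X_gt_df, X_syn_df])`, its `encode()` call produces the shared `encoders`, and every dataset is then encoded with that one mapping.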