diff --git a/src/rsatoolbox/rdm/compare.py b/src/rsatoolbox/rdm/compare.py index 3d7f8420..a6253c76 100644 --- a/src/rsatoolbox/rdm/compare.py +++ b/src/rsatoolbox/rdm/compare.py @@ -671,12 +671,14 @@ def _parse_input_rdms(rdm1, rdm2): vector2 = rdm2 if not vector1.shape[1] == vector2.shape[1]: raise ValueError('rdm1 and rdm2 must be RDMs of equal shape') - nan_idx = ~np.isnan(vector1) - vector1_no_nan = vector1[nan_idx].reshape(vector1.shape[0], -1) - vector2_no_nan = vector2[~np.isnan(vector2)].reshape(vector2.shape[0], -1) - if not vector1_no_nan.shape[1] == vector2_no_nan.shape[1]: + # A NaN in any RDM means that position must be excluded from all + nan_mask = ~np.isnan(vector1).any(axis=0) + if not np.all(nan_mask == ~np.isnan(vector2).any(axis=0)): + # Only raise error when rdm1 and rdm2 conflict raise ValueError('rdm1 and rdm2 have different nan positions') - return vector1_no_nan, vector2_no_nan, nan_idx[0] + vector1_no_nan = vector1[:,nan_mask].reshape(vector1.shape[0], -1) + vector2_no_nan = vector2[:,nan_mask].reshape(vector2.shape[0], -1) + return vector1_no_nan, vector2_no_nan, nan_mask def _sq_bures_metric_first_way(A, B):