From 139552596b7b448f71e23413064da967b77ee5e7 Mon Sep 17 00:00:00 2001
From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com>
Date: Wed, 28 Feb 2024 11:05:45 +0100
Subject: [PATCH] Fix deterministic group_shuffle_split (#1839)

* order sets
* suggestion
* add unit test
* fix
* updated test
* fix
* indices from file
* test util update
* path
* no file
* no file
* comment
* i cannot spell
---
 tests/datamodules/test_utils.py | 34 +++++++++++++++++++++-------------
 torchgeo/datamodules/utils.py   |  4 ++--
 2 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/tests/datamodules/test_utils.py b/tests/datamodules/test_utils.py
index ea410fe4b1d..b96c6032c67 100644
--- a/tests/datamodules/test_utils.py
+++ b/tests/datamodules/test_utils.py
@@ -30,8 +30,11 @@ def test_dataset_split() -> None:
 
 
 def test_group_shuffle_split() -> None:
-    alphabet = np.array(list("abcdefghijklmnopqrstuvwxyz"))
-    groups = np.random.randint(0, 26, size=(1000))
+    train_indices = [0, 2, 5, 6, 7, 8, 9, 10, 11, 13, 14]
+    test_indices = [1, 3, 4, 12]
+    np.random.seed(0)
+    alphabet = np.array(list("abc"))
+    groups = np.random.randint(0, 3, size=(15))
     groups = alphabet[groups]
 
     with pytest.raises(ValueError, match="You must specify `train_size` *"):
@@ -43,16 +46,21 @@ def test_group_shuffle_split() -> None:
         match=re.escape("`train_size` and `test_size` must be in the range (0,1)."),
     ):
         group_shuffle_split(groups, train_size=-0.2, test_size=1.2)
-    with pytest.raises(ValueError, match="26 groups were found, however the current *"):
+    with pytest.raises(ValueError, match="3 groups were found, however the current *"):
         group_shuffle_split(groups, train_size=None, test_size=0.999)
 
-    train_indices, test_indices = group_shuffle_split(
-        groups, train_size=None, test_size=0.2
-    )
-    assert len(set(train_indices) & set(test_indices)) == 0
-    assert len(set(groups[train_indices])) == 21
-    train_indices, test_indices = group_shuffle_split(
-        groups, train_size=0.8, test_size=None
-    )
-    assert len(set(train_indices) & set(test_indices)) == 0
-    assert len(set(groups[train_indices])) == 21
+    test_cases = [(None, 0.2, 42), (0.8, None, 42)]
+
+    for train_size, test_size, random_state in test_cases:
+        train_indices1, test_indices1 = group_shuffle_split(
+            groups,
+            train_size=train_size,
+            test_size=test_size,
+            random_state=random_state,
+        )
+        # Check that the results are the same as expected
+        assert np.array_equal(train_indices, train_indices1)
+        assert np.array_equal(test_indices, test_indices1)
+
+        assert len(set(train_indices1) & set(test_indices1)) == 0
+        assert len(set(groups[train_indices1])) == 2
diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py
index d0bb6af9934..8a002b3c4a6 100644
--- a/torchgeo/datamodules/utils.py
+++ b/torchgeo/datamodules/utils.py
@@ -102,7 +102,7 @@ def group_shuffle_split(
     if train_size <= 0 or train_size >= 1 or test_size <= 0 or test_size >= 1:
         raise ValueError("`train_size` and `test_size` must be in the range (0,1).")
 
-    group_vals = set(groups)
+    group_vals = sorted(set(groups))
     n_groups = len(group_vals)
     n_test_groups = round(n_groups * test_size)
     n_train_groups = n_groups - n_test_groups
@@ -115,7 +115,7 @@ def group_shuffle_split(
 
     generator = np.random.default_rng(seed=random_state)
     train_group_vals = set(
-        generator.choice(list(group_vals), size=n_train_groups, replace=False)
+        generator.choice(group_vals, size=n_train_groups, replace=False)
     )
 
     train_idxs = []