From e36982b87bf87fb9559fc4d124e132b67f177d23 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Thu, 12 Oct 2023 10:02:23 +0100 Subject: [PATCH] update validation epoch (#7121) - this allows for validation epoch at the very beginning of training - fixes #7122 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li --- monai/apps/utils.py | 2 ++ monai/engines/evaluator.py | 2 +- tests/test_handler_validation.py | 6 ++++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index eee004f27d..d2dd63b958 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -203,6 +203,8 @@ def download_url( if urlparse(url).netloc == "drive.google.com": if not has_gdown: raise RuntimeError("To download files from Google Drive, please install the gdown dependency.") + if "fuzzy" not in gdown_kwargs: + gdown_kwargs["fuzzy"] = True # default to true for flexible url gdown.download(url, f"{tmp_name}", quiet=not progress, **gdown_kwargs) elif urlparse(url).netloc == "cloud-api.yandex.net": with urlopen(url) as response: diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 736cde8b88..119853d5c5 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -142,7 +142,7 @@ def run(self, global_epoch: int = 1) -> None: # type: ignore[override] """ # init env value for current validation process - self.state.max_epochs = global_epoch + self.state.max_epochs = max(global_epoch, 1) # at least one epoch of validation self.state.epoch = global_epoch - 1 self.state.iteration = 0 super().run() diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py index 23fcf5e75c..e1ccba2294 100644 --- a/tests/test_handler_validation.py +++ b/tests/test_handler_validation.py @@ -23,7 +23,8 @@ class TestEvaluator(Evaluator): def _iteration(self, engine, batchdata): - pass + engine.state.output = "called" + return engine.state.output class TestHandlerValidation(unittest.TestCase): @@ -42,8 +43,9 @@ def _train_func(engine, batch): ValidationHandler(interval=2, validator=evaluator, exec_at_start=True).attach(engine) # test execution at start engine.run(data, max_epochs=1) - self.assertEqual(evaluator.state.max_epochs, 0) + self.assertEqual(evaluator.state.max_epochs, 1) self.assertEqual(evaluator.state.epoch_length, 8) + self.assertEqual(evaluator.state.output, "called") engine.run(data, max_epochs=5) self.assertEqual(evaluator.state.max_epochs, 4)