From 598627bc38322c868fa9be0f8ffacedd5263adc6 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 25 Oct 2024 10:47:36 -0400 Subject: [PATCH] Add `fast_dev_run=False` in default trainer config (#74) * Add `fast_dev_run=False` in default trainer config Signed-off-by: Fabrice Normandin * Add xfail for flaky integration test Signed-off-by: Fabrice Normandin * Adjust the fast_dev_run in `test_run_remote_job` Signed-off-by: Fabrice Normandin --------- Signed-off-by: Fabrice Normandin --- project/algorithms/jax_rl_example_test.py | 1 + project/configs/trainer/default.yaml | 2 ++ project/main_test.py | 2 +- project/utils/remote_launcher_plugin_test.py | 2 +- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/project/algorithms/jax_rl_example_test.py b/project/algorithms/jax_rl_example_test.py index e70c17e7..e859e658 100644 --- a/project/algorithms/jax_rl_example_test.py +++ b/project/algorithms/jax_rl_example_test.py @@ -700,6 +700,7 @@ def lightning_trainer(max_epochs: int, tmp_path: Path): # reducing the max_epochs from 75 down to 10 because it's just wayyy too slow. +@pytest.mark.xfail(reason="Seems to not be completely reproducible") @pytest.mark.slow # @pytest.mark.timeout(80) @pytest.mark.parametrize("max_epochs", [15], indirect=True) diff --git a/project/configs/trainer/default.yaml b/project/configs/trainer/default.yaml index a9cdc147..1b463ff3 100644 --- a/project/configs/trainer/default.yaml +++ b/project/configs/trainer/default.yaml @@ -8,6 +8,8 @@ devices: 1 deterministic: true +fast_dev_run: false + min_epochs: 1 max_epochs: 10 diff --git a/project/main_test.py b/project/main_test.py index 94ad53cd..4e0f41c3 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -72,7 +72,7 @@ def test_example_experiment_defaults(experiment_config: Config) -> None: ) -@use_overrides(["algorithm=example datamodule=cifar10 seed=1 +trainer.fast_dev_run=True"]) +@use_overrides(["algorithm=example datamodule=cifar10 seed=1 trainer.fast_dev_run=True"]) def test_fast_dev_run(experiment_dictconfig: DictConfig): result = main(experiment_dictconfig) assert isinstance(result, dict) diff --git a/project/utils/remote_launcher_plugin_test.py b/project/utils/remote_launcher_plugin_test.py index a412bf00..d939b2cc 100644 --- a/project/utils/remote_launcher_plugin_test.py +++ b/project/utils/remote_launcher_plugin_test.py @@ -99,7 +99,7 @@ def test_can_load_configs(command_line_arguments: list[str]): # otherwise it will use the local launcher! "resources=gpu", "cluster=mila", - "+trainer.fast_dev_run=True", + "trainer.fast_dev_run=True", ] ], )