From ba949f7ab89827a9bf4643110a58d2eefbd63891 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 28 Jun 2024 10:44:34 +0100 Subject: [PATCH] Better check of config arguments --- src/anemoi/training/commands/train.py | 8 ++++++-- src/anemoi/training/config/config.yaml | 6 ++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/anemoi/training/commands/train.py b/src/anemoi/training/commands/train.py index 07ea078a..3b44434e 100644 --- a/src/anemoi/training/commands/train.py +++ b/src/anemoi/training/commands/train.py @@ -61,12 +61,15 @@ def run(self, args): for config in args.config: if override_regex.match(config): overrides.append(config) - else: + elif config.endswith(".yaml") or config.endswith(".yml"): configs.append(config) + else: + raise ValueError(f"Invalid config '{config}'. It must be a yaml file or an override") hydra.initialize(config_path="../config", version_base=None) - cfg = hydra.compose(config_name="config", overrides=overrides) + cfg = hydra.compose(config_name="config") # , overrides=overrides) + print(cfg) # Add user config user_config = config_path("training.yaml") @@ -82,6 +85,7 @@ def run(self, args): cfg = OmegaConf.merge(cfg, OmegaConf.load(config)) # We need to reapply the overrides + # This does not support overrides with a prefix cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(overrides)) print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=4)) diff --git a/src/anemoi/training/config/config.yaml b/src/anemoi/training/config/config.yaml index 7d0c7f22..604b5f20 100644 --- a/src/anemoi/training/config/config.yaml +++ b/src/anemoi/training/config/config.yaml @@ -1,14 +1,16 @@ defaults: - - _self_ +- _self_ model: num_channels: 128 + dataloader: limit_batches: training: 100 validation: 100 + training: max_epochs: 3 token: - mflow: None + mlflow: null