Skip to content

Commit

Permalink
copy commands/train.py
Browse files Browse the repository at this point in the history
Co-authored-by: Baudouin Raoult <[email protected]>
  • Loading branch information
theissenhelen and b8raoult committed Aug 1, 2024
1 parent 69b76a6 commit 10dc231
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ dynamic = [
"version",
]
dependencies = [
"anemoi-datasets>=0.1",
"anemoi-datasets[data]>=0.4",
"anemoi-models",
"anemoi-utils[provenance]>=0.3.7",
"anemoi-utils[provenance]>=0.3.10",
"einops>=0.6.1",
"hydra-core>=1.3",
"matplotlib>=3.7.1",
Expand Down
95 changes: 95 additions & 0 deletions src/anemoi/training/commands/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


import json
import logging
import os
import re

import hydra
from anemoi.utils.config import config_path
from omegaconf import OmegaConf

from . import Command

LOGGER = logging.getLogger(__name__)

# https://hydra.cc/docs/advanced/override_grammar/basic/

override_regex = re.compile(
r"""
^
(
(\w+)([/@:\.]\w+)* # key
= # assignment
(.*) # value
)
$
""",
re.VERBOSE,
)


class Train(Command):

def add_arguments(self, command_parser):
command_parser.add_argument(
"config",
nargs="*",
type=str,
help="A list yaml files to load or a list of overrides to apply",
)

def run(self, args):

configs = []
overrides = []

for config in args.config:
if override_regex.match(config):
overrides.append(config)
elif config.endswith(".yaml") or config.endswith(".yml"):
configs.append(config)
else:
raise ValueError(f"Invalid config '{config}'. It must be a yaml file or an override")

# We could apply the overrides here. To be tested
hydra.initialize(config_path="../config", version_base=None)

cfg = hydra.compose(config_name="config")

# Add user config
user_config = config_path("training.yaml")

if os.path.exists(user_config):
LOGGER.info(f"Loading config {user_config}")
cfg = OmegaConf.merge(cfg, OmegaConf.load(user_config))

# Add extra config files specified in the command line

for config in configs:
LOGGER.info(f"Loading config {config}")
cfg = OmegaConf.merge(cfg, OmegaConf.load(config))

# Apply overrides
# OmegaConf do not implement the prefix logic, this is done by hydra
# If needed, the logic can be implemented here (look in the git history for an example)
OmegaConf.merge(cfg, OmegaConf.from_dotlist(overrides))

# Resolve the config
OmegaConf.resolve(cfg)

print(json.dumps(OmegaConf.to_container(cfg), indent=4))

# AIFSTrainer(cfg).train()


command = Train

0 comments on commit 10dc231

Please sign in to comment.