
Add command anemoi-training train ... #3

Closed
wants to merge 14 commits into from
24 changes: 24 additions & 0 deletions docs/cli/introduction.rst
@@ -0,0 +1,24 @@
Introduction
============

When you install the `anemoi-training` package, this also installs a
command line tool called ``anemoi-training`` which can be used to train
models.

The tool provides help with the ``--help`` option:

.. code-block:: bash

% anemoi-training --help

The commands are:

.. toctree::
:maxdepth: 1

train

.. argparse::
:module: anemoi.training.__main__
:func: create_parser
:prog: anemoi-training
:nosubcommands:
49 changes: 49 additions & 0 deletions docs/cli/train.rst
@@ -0,0 +1,49 @@
#######
train
#######

Use this command to train a model:

.. code:: bash

% anemoi-training train config.yaml

The command will read the default configuration and override it with the
values in the provided configuration file. The configuration file should
be a YAML file with the structure defined in the `Configuration`
section. The file `config.yaml` typically describes the model to be
trained, the dataset to be used, and the training hyperparameters:

.. literalinclude:: train.yaml
:language: yaml

You can provide more than one configuration file, in which case the
values will be merged in the order they are provided. A typical usage
would be to split the training configuration into model description,
training hyperparameters and runtime options:

.. code:: bash

% anemoi-training train model.yaml hyperparameters.yaml slurm.yaml
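
The merge is a deep merge in the style of OmegaConf: later files win on
conflicting keys, while unrelated keys from every file survive. A minimal
pure-Python sketch of that behaviour (illustrative only, not the tool's
actual code path; the key names are taken from the example configuration):

.. code:: python

   def deep_merge(base: dict, override: dict) -> dict:
       """Recursively merge two config dicts; keys in `override` win."""
       merged = dict(base)
       for key, value in override.items():
           if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
               merged[key] = deep_merge(merged[key], value)
           else:
               merged[key] = value
       return merged

   model = {"model": {"num_channels": 128}, "training": {"max_epochs": 3}}
   hyper = {"training": {"max_epochs": 10}}
   print(deep_merge(model, hyper))
   # {'model': {'num_channels': 128}, 'training': {'max_epochs': 10}}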

Furthermore, you can also provide values directly on the command line,
which will override any values in the configuration files:

.. code:: bash

% anemoi-training train config.yaml tracker.mlflow.tracking_uri=http://localhost:5000
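
Each dotted key expands into a nested configuration entry before being
merged, along the lines of OmegaConf's dotlist parsing. A small sketch of
the expansion (illustrative only):

.. code:: python

   def expand_dotkey(assignment: str) -> dict:
       """Expand 'a.b.c=value' into {'a': {'b': {'c': 'value'}}}."""
       dotkey, value = assignment.split("=", 1)
       nested = value
       for part in reversed(dotkey.split(".")):
           nested = {part: nested}
       return nested

   print(expand_dotkey("tracker.mlflow.tracking_uri=http://localhost:5000"))
   # {'tracker': {'mlflow': {'tracking_uri': 'http://localhost:5000'}}}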

If the file `~/.config/anemoi/train.yaml` exists, it will be loaded
after the defaults and before any other configuration file. This allows
you to provide values such as passwords or other sensitive information
that you do not want to store in a git repository.
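
For example, a user-level file holding a tracker token might look like
this (the key names below are illustrative, mirroring the example
configuration):

.. code:: yaml

   token:
     mlflow: "my-secret-token"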

********************
Command line usage
********************

.. argparse::
:module: anemoi.training.__main__
:func: create_parser
:prog: anemoi-training
:path: train
2 changes: 2 additions & 0 deletions docs/cli/train.yaml
Is this "just an example"?

@@ -0,0 +1,2 @@
training:
max_epochs: 10
13 changes: 13 additions & 0 deletions docs/index.rst
@@ -30,6 +30,19 @@ of the *Anemoi* packages.

installing

**Command line tool**

- :doc:`cli/introduction`
- :doc:`cli/train`

.. toctree::
:maxdepth: 1
:hidden:
:caption: Command line tool

cli/introduction
cli/train

*****************
Anemoi packages
*****************
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -51,7 +51,7 @@ dynamic = [
]
dependencies = [
"anemoi-datasets[data]>=0.1",
"anemoi-models @ git+https://github.com/ecmwf/anemoi-models.git",
"anemoi-models",
"anemoi-utils[provenance]>=0.1.3",
"einops>=0.6.1",
"hydra-core>=1.3",
168 changes: 168 additions & 0 deletions src/anemoi/training/commands/train.py
@@ -0,0 +1,168 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


import json
import logging
import os
import re

import hydra
from anemoi.utils.config import config_path
from hydra.errors import ConfigCompositionException
from omegaconf import OmegaConf

from . import Command

LOGGER = logging.getLogger(__name__)

# https://hydra.cc/docs/advanced/override_grammar/basic/

override_regex = re.compile(
r"""
^
(
(~|\+|\+\+)? # optional prefix
(\w+)([/@:\.]\w+)* # key
= # assignment
(.*) # value
)
| # or
(~ # ~ prefix
(\w+)([/@:\.]\w+)* # key
)
$
""",
re.VERBOSE,
)


def apply_delete_override(cfg, dotkey, value, value_given):

any_value = object()

if not value_given:
assert value is None
value = any_value

current = OmegaConf.select(cfg, dotkey, throw_on_missing=False)
if value not in (any_value, current):
raise ConfigCompositionException(
f"Key '{dotkey}' with value '{current}' does not match the value '{value}' in the override"
)

try:
# Allow 'del'
OmegaConf.set_struct(cfg, False)

if "." in dotkey:
parent, key = dotkey.rsplit(".", 1)
subtree = OmegaConf.select(cfg, parent)
del subtree[key]
else:
# Top level key
del cfg[dotkey]

finally:
OmegaConf.set_struct(cfg, True)


def apply_add_override_force(cfg, dotkey, value):
OmegaConf.update(cfg, dotkey, value, merge=True, force_add=True)


def apply_add_override(cfg, dotkey, value):
current = OmegaConf.select(cfg, dotkey, throw_on_missing=False)
if current is not None:
raise ConfigCompositionException(f"Cannot add key '{dotkey}' because it already exists, use '++' to force add")

OmegaConf.update(cfg, dotkey, value, merge=True, force_add=True)


def apply_assign_override(cfg, dotkey, value):
OmegaConf.update(cfg, dotkey, value, merge=True)


def parse_override(override, n):
dotkey = override[n:]
parsed = OmegaConf.from_dotlist([dotkey])
dotkey = dotkey.split("=")[0]
value = OmegaConf.select(parsed, dotkey)
return dotkey, value


def apply_override(cfg, override):
if override.startswith("~"):
return apply_delete_override(cfg, *parse_override(override, 1), value_given="=" in override)

if override.startswith("++"):
return apply_add_override_force(cfg, *parse_override(override, 2))

if override.startswith("+"):
return apply_add_override(cfg, *parse_override(override, 1))

return apply_assign_override(cfg, *parse_override(override, 0))


class Train(Command):

def add_arguments(self, command_parser):
command_parser.add_argument(
"config",
nargs="*",
type=str,
help="A list of yaml files to load or a list of overrides to apply",
)

def run(self, args):

configs = []
overrides = []

for config in args.config:
if override_regex.match(config):
overrides.append(config)
elif config.endswith(".yaml") or config.endswith(".yml"):
configs.append(config)
else:
raise ValueError(f"Invalid config '{config}'. It must be a yaml file or an override")

hydra.initialize(config_path="../config", version_base=None)

cfg = hydra.compose(config_name="config")

# Add user config
user_config = config_path("train.yaml")

if os.path.exists(user_config):
This should use Pathlib

LOGGER.info(f"Loading config {user_config}")
cfg = OmegaConf.merge(cfg, OmegaConf.load(user_config))

# Add extra config files specified in the command line

for config in configs:
LOGGER.info(f"Loading config {config}")
cfg = OmegaConf.merge(cfg, OmegaConf.load(config))

# We need to reapply the overrides.
# OmegaConf does not implement the prefix logic; this is done by hydra.
for override in overrides:
LOGGER.info(f"Applying override {override}")
apply_override(cfg, override)

# Resolve the config
OmegaConf.resolve(cfg)

print(json.dumps(OmegaConf.to_container(cfg), indent=4))
This should probably use the logger if we want this as an output.


# AIFSTrainer(cfg).train()


command = Train
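As a quick sanity check on the override grammar, the same pattern
(copied here so the snippet is self-contained) distinguishes overrides
from plain file arguments; the candidate values are illustrative:

```python
import re

# Same grammar as override_regex in train.py: an optional ~ / + / ++
# prefix, a dotted (or /, @, : separated) key, and an optional =value.
override_regex = re.compile(
    r"""
    ^
    (
        (~|\+|\+\+)?        # optional prefix
        (\w+)([/@:\.]\w+)*  # key
        =                   # assignment
        (.*)                # value
    )
    |                       # or
    (~                      # ~ prefix
        (\w+)([/@:\.]\w+)*  # key
    )
    $
    """,
    re.VERBOSE,
)

for candidate in [
    "training.max_epochs=10",     # plain assignment
    "+model.hidden_size=256",     # add (fails if the key exists)
    "++model.hidden_size=256",    # force add
    "~dataloader.limit_batches",  # delete, no value check
    "config.yaml",                # not an override: treated as a file
]:
    print(candidate, "->", bool(override_regex.match(candidate)))
```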
Empty file.
16 changes: 16 additions & 0 deletions src/anemoi/training/config/config.yaml
@@ -0,0 +1,16 @@
defaults:
- _self_

model:
num_channels: 128

dataloader:
limit_batches:
training: 100
validation: 100

training:
max_epochs: 3

token:
I know this is just an early example, but fyi the token will not be stored in the user config (it will be in its own config file and generated by code with the anemoi-training mlflow login cmd, not input by the user)

mlflow: 8