Add command anemoi-training train ... (#3)
@@ -0,0 +1,24 @@
Introduction
============

When you install the ``anemoi-training`` package, this will also install a
command-line tool called ``anemoi-training`` which can be used to train
models.

The tool can provide help with the ``--help`` option:

.. code-block:: bash

   % anemoi-training --help

The commands are:

.. toctree::
   :maxdepth: 1

   train

.. argparse::
   :module: anemoi.training.__main__
   :func: create_parser
   :prog: anemoi-training
   :nosubcommands:
@@ -0,0 +1,49 @@
#######
 train
#######

Use this command to train a model:

.. code:: bash

   % anemoi-training train config.yaml

The command will read the default configuration and override it with the
values in the provided configuration file. The configuration file should
be a YAML file with the structure defined in the `Configuration`
section. The file `config.yaml` will typically describe the model to be
trained, the dataset to be used, and the training hyperparameters:

.. literalinclude:: train.yaml
   :language: yaml

You can provide more than one configuration file, in which case the
values will be merged in the order they are provided. A typical usage
would be to split the training configuration into model description,
training hyperparameters and runtime options:

.. code:: bash

   % anemoi-training train model.yaml hyperparameters.yaml slurm.yaml
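The "later files win" merge order described above can be sketched with plain dictionaries. Note this is an illustration only: `deep_merge` is a hypothetical helper, and the real tool merges with `OmegaConf.merge`, which also handles typing and interpolation.

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base`; later values win."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# A model description followed by a hyperparameter file:
model = {"model": {"num_channels": 128}, "training": {"max_epochs": 3}}
hyper = {"training": {"max_epochs": 10}}

print(deep_merge(model, hyper))
# {'model': {'num_channels': 128}, 'training': {'max_epochs': 10}}
```

Keys that appear in both files are overridden by the later file, while untouched subtrees (here ``model``) are kept intact.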
Furthermore, you can also provide values directly on the command line,
which will override any values in the configuration files:

.. code:: bash

   % anemoi-training train config.yaml tracker.mlflow.tracking_uri=http://localhost:5000
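The dotted-key syntax used above can be sketched as follows. `set_dotted` is a hypothetical helper for illustration; the actual command parses overrides with `OmegaConf.from_dotlist`, which also converts values to their proper types (this sketch keeps them as strings).

```python
def set_dotted(cfg: dict, override: str) -> None:
    """Apply a single `a.b.c=value` override to a nested dict in place."""
    dotkey, _, value = override.partition("=")
    *parents, leaf = dotkey.split(".")
    node = cfg
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value  # values stay strings in this sketch


cfg = {"tracker": {"mlflow": {"tracking_uri": "http://example.invalid"}}}
set_dotted(cfg, "tracker.mlflow.tracking_uri=http://localhost:5000")
print(cfg["tracker"]["mlflow"]["tracking_uri"])
# http://localhost:5000
```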
If the file `~/.config/anemoi/train.yaml` exists, it will be loaded
after the defaults and before any other configuration file. This allows
you to provide values such as passwords or other sensitive information
that you do not want to store in a git repository.
********************
 Command line usage
********************

.. argparse::
   :module: anemoi.training.__main__
   :func: create_parser
   :prog: anemoi-training
   :path: train
@@ -0,0 +1,2 @@
training:
  max_epochs: 10
@@ -0,0 +1,168 @@
#!/usr/bin/env python
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import json
import logging
import os
import re

import hydra
from anemoi.utils.config import config_path
from hydra.errors import ConfigCompositionException
from omegaconf import OmegaConf

from . import Command

LOGGER = logging.getLogger(__name__)

# https://hydra.cc/docs/advanced/override_grammar/basic/

override_regex = re.compile(
    r"""
    ^
    (
        (~|\+|\+\+)?        # optional prefix
        (\w+)([/@:\.]\w+)*  # key
        =                   # assignment
        (.*)                # value
    )
    |                       # or
    (~                      # ~ prefix
        (\w+)([/@:\.]\w+)*  # key
    )
    $
    """,
    re.VERBOSE,
)


def apply_delete_override(cfg, dotkey, value, value_given):

    any_value = object()

    if not value_given:
        assert value is None
        value = any_value

    current = OmegaConf.select(cfg, dotkey, throw_on_missing=False)
    if value not in (any_value, current):
        raise ConfigCompositionException(
            f"Key '{dotkey}' with value '{current}' does not match the value '{value}' in the override"
        )

    try:
        # Allow 'del'
        OmegaConf.set_struct(cfg, False)

        if "." in dotkey:
            parent, key = dotkey.rsplit(".", 1)
            subtree = OmegaConf.select(cfg, parent)
            del subtree[key]
        else:
            # Top level key
            del cfg[dotkey]

    finally:
        OmegaConf.set_struct(cfg, True)


def apply_add_override_force(cfg, dotkey, value):
    OmegaConf.update(cfg, dotkey, value, merge=True, force_add=True)


def apply_add_override(cfg, dotkey, value):
    current = OmegaConf.select(cfg, dotkey, throw_on_missing=False)
    if current is not None:
        raise ConfigCompositionException(f"Cannot add key '{dotkey}' because it already exists, use '++' to force add")

    OmegaConf.update(cfg, dotkey, value, merge=True, force_add=True)


def apply_assign_override(cfg, dotkey, value):
    OmegaConf.update(cfg, dotkey, value, merge=True)


def parse_override(override, n):
    dotkey = override[n:]
    parsed = OmegaConf.from_dotlist([dotkey])
    dotkey = dotkey.split("=")[0]
    value = OmegaConf.select(parsed, dotkey)
    return dotkey, value


def apply_override(cfg, override):
    if override.startswith("~"):
        return apply_delete_override(cfg, *parse_override(override, 1), value_given="=" in override)

    if override.startswith("++"):
        return apply_add_override_force(cfg, *parse_override(override, 2))

    if override.startswith("+"):
        return apply_add_override(cfg, *parse_override(override, 1))

    return apply_assign_override(cfg, *parse_override(override, 0))


class Train(Command):

    def add_arguments(self, command_parser):
        command_parser.add_argument(
            "config",
            nargs="*",
            type=str,
            help="A list of yaml files to load or a list of overrides to apply",
        )

    def run(self, args):

        configs = []
        overrides = []

        for config in args.config:
            if override_regex.match(config):
                overrides.append(config)
            elif config.endswith(".yaml") or config.endswith(".yml"):
                configs.append(config)
            else:
                raise ValueError(f"Invalid config '{config}'. It must be a yaml file or an override")

        hydra.initialize(config_path="../config", version_base=None)

        cfg = hydra.compose(config_name="config")

        # Add user config
        user_config = config_path("train.yaml")

        if os.path.exists(user_config):
            # Review comment: This should use pathlib.
            LOGGER.info(f"Loading config {user_config}")
            cfg = OmegaConf.merge(cfg, OmegaConf.load(user_config))

        # Add extra config files specified on the command line

        for config in configs:
            LOGGER.info(f"Loading config {config}")
            cfg = OmegaConf.merge(cfg, OmegaConf.load(config))

        # We need to reapply the overrides.
        # OmegaConf does not implement the prefix logic; this is done by hydra.
        for override in overrides:
            LOGGER.info(f"Applying override {override}")
            apply_override(cfg, override)

        # Resolve the config
        OmegaConf.resolve(cfg)

        print(json.dumps(OmegaConf.to_container(cfg), indent=4))
        # Review comment: This should probably use the logger if we want this as an output.

        # AIFSTrainer(cfg).train()


command = Train
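The classification done in `run()` above — command-line tokens that are overrides versus tokens that are YAML files — can be exercised directly with the override regex (the pattern below is copied verbatim from the listing):

```python
import re

# Same pattern as in the command implementation above.
override_regex = re.compile(
    r"""
    ^
    (
        (~|\+|\+\+)?        # optional prefix
        (\w+)([/@:\.]\w+)*  # key
        =                   # assignment
        (.*)                # value
    )
    |                       # or
    (~                      # ~ prefix
        (\w+)([/@:\.]\w+)*  # key
    )
    $
    """,
    re.VERBOSE,
)

for token in (
    "training.max_epochs=10",  # plain assignment
    "+extra.key=1",            # add (errors if the key exists)
    "++extra.key=1",           # force add
    "~training.max_epochs",    # delete
    "config.yaml",             # no '=' and no '~': treated as a file
):
    print(token, "->", "override" if override_regex.match(token) else "file")
```

Note the two alternation branches: the first accepts an optional `~`/`+`/`++` prefix followed by `key=value`, while the second accepts a bare `~key` with no value, which is why a plain filename such as `config.yaml` falls through to the file list.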
@@ -0,0 +1,16 @@
defaults:
  - _self_

model:
  num_channels: 128

dataloader:
  limit_batches:
    training: 100
    validation: 100

training:
  max_epochs: 3

token:
  # Review comment: I know this is just an early example, but fyi the token
  # will not be stored in the user config (it will be in its own config file
  # and generated by code with the
  mlflow: 8
Review comment: Is this "just an example"?