Skip to content

Commit

Permalink
add metadata saving to compiled model
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Apr 23, 2024
1 parent ecc6e41 commit 6878c75
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 106 deletions.
2 changes: 1 addition & 1 deletion mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
from mace.tools.finetuning_utils import extract_load
from mace.tools.scripts_utils import extract_load


def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
Expand Down
48 changes: 41 additions & 7 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Optional

import numpy as np
import torch.distributed
import torch.nn.functional
from e3nn import o3
from e3nn.util import jit
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.swa_utils import SWALR, AveragedModel
from torch_ema import ExponentialMovingAverage
Expand All @@ -24,14 +26,17 @@
from mace import data, modules, tools
from mace.calculators.foundations_models import mace_mp, mace_off
from mace.tools import torch_geometric
from mace.tools.finetuning_utils import extract_config_mace_model, load_foundations
from mace.tools.finetuning_utils import load_foundations
from mace.tools.scripts_utils import (
LRScheduler,
convert_to_json_format,
create_error_table,
extract_config_mace_model,
get_atomic_energies,
get_config_type_weights,
get_dataset_from_xyz,
get_files_with_suffix,
print_git_commit,
)
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.utils import AtomicNumberTable
Expand Down Expand Up @@ -72,7 +77,7 @@ def main() -> None:

tools.set_default_dtype(args.default_dtype)
device = tools.init_device(args.device)

commit = print_git_commit()
if args.foundation_model is not None:
if args.foundation_model in ["small", "medium", "large"]:
logging.info(
Expand Down Expand Up @@ -559,15 +564,13 @@ def main() -> None:
assert dipole_only is False, "swa for dipole fitting not implemented"
swas.append(True)
if args.start_swa is None:
args.start_swa = (
args.max_num_epochs // 4 * 3
) # if not set start swa at 75% of training
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
else:
if args.start_swa > args.max_num_epochs:
logging.info(
f"Start swa must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}"
)
args.start_swa = args.max_num_epochs // 4 * 3
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
logging.info(f"Setting start swa to {args.start_swa}")
if args.loss == "forces_only":
logging.info("Can not select swa with forces only loss.")
Expand Down Expand Up @@ -786,11 +789,42 @@ def main() -> None:
if args.save_cpu:
model = model.to("cpu")
torch.save(model, model_path)

extra_files = {
"commit.txt": commit.encode("utf-8"),
"config.yaml": json.dumps(
convert_to_json_format(extract_config_mace_model(model))
),
}
if swa_eval:
torch.save(model, Path(args.model_dir) / (args.name + "_swa.model"))
try:
path_complied = Path(args.model_dir) / (
args.name + "_swa_compiled.model"
)
logging.info(f"Compiling model, saving metadata {path_complied}")
model_compiled = jit.compile(deepcopy(model))
torch.jit.save(
model_compiled,
path_complied,
_extra_files=extra_files,
)
except Exception as e: # pylint: disable=W0703
pass
else:
torch.save(model, Path(args.model_dir) / (args.name + ".model"))
try:
path_complied = Path(args.model_dir) / (
args.name + "_compiled.model"
)
logging.info(f"Compiling model, saving metadata to {path_complied}")
model_compiled = jit.compile(deepcopy(model))
torch.jit.save(
model_compiled,
path_complied,
_extra_files=extra_files,
)
except Exception as e: # pylint: disable=W070344
pass

if args.distributed:
torch.distributed.barrier()
Expand Down
17 changes: 14 additions & 3 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,21 +208,32 @@ def load_from_xyz(
)
energy_key = "REF_energy"
for atoms in atoms_list:
atoms.info["REF_energy"] = atoms.get_potential_energy()
try:
atoms.info["REF_energy"] = atoms.get_potential_energy()
except Exception as e: # pylint: disable=W0703
logging.warning(f"Failed to extract energy: {e}")
atoms.info["REF_energy"] = None
if forces_key == "forces":
logging.info(
"Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to 'REF_forces'"
)
forces_key = "REF_forces"
for atoms in atoms_list:
atoms.info["REF_forces"] = atoms.get_forces()
try:
atoms.info["REF_forces"] = atoms.get_forces()
except Exception as e: # pylint: disable=W0703
logging.warning(f"Failed to extract forces: {e}")
atoms.info["REF_forces"] = None
if stress_key == "stress":
logging.info(
"Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to 'REF_stress'"
)
stress_key = "REF_stress"
for atoms in atoms_list:
atoms.info["REF_stress"] = atoms.get_stress()
try:
atoms.info["REF_stress"] = atoms.get_stress()
except Exception as e: # pylint: disable=W0703
atoms.info["REF_stress"] = None
if not isinstance(atoms_list, list):
atoms_list = [atoms_list]

Expand Down
3 changes: 1 addition & 2 deletions mace/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser
from .cg import U_matrix_real
from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState
from .finetuning_utils import extract_load, load_foundations
from .finetuning_utils import load_foundations
from .torch_tools import (
TensorDict,
cartesian_to_spherical,
Expand Down Expand Up @@ -66,6 +66,5 @@
"voigt_to_matrix",
"init_wandb",
"load_foundations",
"extract_load",
"build_preprocess_arg_parser",
]
120 changes: 29 additions & 91 deletions mace/tools/finetuning_utils.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,8 @@
from typing import Any, Dict

import torch
from e3nn import o3

from mace.tools.utils import AtomicNumberTable


def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]:
def radial_to_name(radial_type):
if radial_type == "BesselBasis":
return "bessel"
if radial_type == "GaussianBasis":
return "gaussian"
if radial_type == "ChebychevBasis":
return "chebyshev"
return radial_type

def radial_to_transform(radial):
if not hasattr(radial, "distance_transform"):
return None
if radial.distance_transform.__class__.__name__ == "AgnesiTransform":
return "Agnesi"
if radial.distance_transform.__class__.__name__ == "SoftTransform":
return "Soft"
return radial.distance_transform.__class__.__name__

config = {
"r_max": model.r_max.item(),
"num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights),
"num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(),
"max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access
"interaction_cls": model.interactions[-1].__class__,
"interaction_cls_first": model.interactions[0].__class__,
"num_interactions": model.num_interactions.item(),
"num_elements": len(model.atomic_numbers),
"hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
"MLP_irreps": o3.Irreps(str(model.readouts[-1].hidden_irreps)),
"gate": model.readouts[-1] # pylint: disable=protected-access
.non_linearity._modules["acts"][0]
.f,
"atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
"avg_num_neighbors": model.interactions[0].avg_num_neighbors,
"atomic_numbers": model.atomic_numbers,
"correlation": len(
model.products[0].symmetric_contractions.contractions[0].weights
)
+ 1,
"radial_type": radial_to_name(
model.radial_embedding.bessel_fn.__class__.__name__
),
"radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1],
"pair_repulsion": hasattr(model, "pair_repulsion_fn"),
"distance_transform": radial_to_transform(model.radial_embedding),
"atomic_inter_scale": model.scale_shift.scale.item(),
"atomic_inter_shift": model.scale_shift.shift.item(),
}
return config


def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module:
model = torch.load(f=f, map_location=map_location)
model_copy = model.__class__(**extract_config_mace_model(model))
model_copy.load_state_dict(model.state_dict())
return model_copy.to(map_location)


def load_foundations(
model: torch.nn.Module,
model_foundations: torch.nn.Module,
Expand Down Expand Up @@ -112,24 +50,24 @@ def load_foundations(
for j in range(4): # Assuming 4 layers in conv_tp_weights,
layer_name = f"layer{j}"
if j == 0:
getattr(model.interactions[i].conv_tp_weights, layer_name).weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].conv_tp_weights,
layer_name,
)
.weight[:num_radial, :]
.clone()
getattr(
model.interactions[i].conv_tp_weights, layer_name
).weight = torch.nn.Parameter(
getattr(
model_foundations.interactions[i].conv_tp_weights,
layer_name,
)
.weight[:num_radial, :]
.clone()
)
else:
getattr(model.interactions[i].conv_tp_weights, layer_name).weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].conv_tp_weights,
layer_name,
).weight.clone()
)
getattr(
model.interactions[i].conv_tp_weights, layer_name
).weight = torch.nn.Parameter(
getattr(
model_foundations.interactions[i].conv_tp_weights,
layer_name,
).weight.clone()
)

model.interactions[i].linear.weight = torch.nn.Parameter(
Expand Down Expand Up @@ -167,23 +105,23 @@ def load_foundations(
for i in range(2): # Assuming 2 products modules
max_range = max_L + 1 if i == 0 else 1
for j in range(max_range): # Assuming 3 contractions in symmetric_contractions
model.products[i].symmetric_contractions.contractions[j].weights_max = (
torch.nn.Parameter(
model_foundations.products[i]
.symmetric_contractions.contractions[j]
.weights_max[indices_weights, :, :]
.clone()
)
model.products[i].symmetric_contractions.contractions[
j
].weights_max = torch.nn.Parameter(
model_foundations.products[i]
.symmetric_contractions.contractions[j]
.weights_max[indices_weights, :, :]
.clone()
)

for k in range(2): # Assuming 2 weights in each contraction
model.products[i].symmetric_contractions.contractions[j].weights[k] = (
torch.nn.Parameter(
model_foundations.products[i]
.symmetric_contractions.contractions[j]
.weights[k][indices_weights, :, :]
.clone()
)
model.products[i].symmetric_contractions.contractions[j].weights[
k
] = torch.nn.Parameter(
model_foundations.products[i]
.symmetric_contractions.contractions[j]
.weights[k][indices_weights, :, :]
.clone()
)

model.products[i].linear.weight = torch.nn.Parameter(
Expand Down
Loading

0 comments on commit 6878c75

Please sign in to comment.