diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 26a53eac..c7a543a1 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -18,7 +18,7 @@ from mace.modules.utils import extract_invariant from mace.tools import torch_geometric, torch_tools, utils from mace.tools.compile import prepare -from mace.tools.finetuning_utils import extract_load +from mace.tools.scripts_utils import extract_load def get_model_dtype(model: torch.nn.Module) -> torch.dtype: diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index c85b55dc..050e896a 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -9,6 +9,7 @@ import json import logging import os +from copy import deepcopy from pathlib import Path from typing import Optional @@ -16,6 +17,7 @@ import torch.distributed import torch.nn.functional from e3nn import o3 +from e3nn.util import jit from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim.swa_utils import SWALR, AveragedModel from torch_ema import ExponentialMovingAverage @@ -24,14 +26,17 @@ from mace import data, modules, tools from mace.calculators.foundations_models import mace_mp, mace_off from mace.tools import torch_geometric -from mace.tools.finetuning_utils import extract_config_mace_model, load_foundations +from mace.tools.finetuning_utils import load_foundations from mace.tools.scripts_utils import ( LRScheduler, + convert_to_json_format, create_error_table, + extract_config_mace_model, get_atomic_energies, get_config_type_weights, get_dataset_from_xyz, get_files_with_suffix, + print_git_commit, ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable @@ -72,7 +77,7 @@ def main() -> None: tools.set_default_dtype(args.default_dtype) device = tools.init_device(args.device) - + commit = print_git_commit() if args.foundation_model is not None: if args.foundation_model in ["small", "medium", "large"]: logging.info( @@ -559,15 +564,13 @@ def main() -> None: assert dipole_only is False, "swa for dipole fitting not implemented" swas.append(True) if args.start_swa is None: - args.start_swa = ( - args.max_num_epochs // 4 * 3 - ) # if not set start swa at 75% of training + args.start_swa = max(1, args.max_num_epochs // 4 * 3) else: if args.start_swa > args.max_num_epochs: logging.info( f"Start swa must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" ) - args.start_swa = args.max_num_epochs // 4 * 3 + args.start_swa = max(1, args.max_num_epochs // 4 * 3) logging.info(f"Setting start swa to {args.start_swa}") if args.loss == "forces_only": logging.info("Can not select swa with forces only loss.") @@ -786,11 +789,42 @@ def main() -> None: if args.save_cpu: model = model.to("cpu") torch.save(model, model_path) - + extra_files = { + "commit.txt": commit.encode("utf-8"), + "config.yaml": json.dumps( + convert_to_json_format(extract_config_mace_model(model)) + ), + } if swa_eval: torch.save(model, Path(args.model_dir) / (args.name + "_swa.model")) + try: + path_complied = Path(args.model_dir) / ( + args.name + "_swa_compiled.model" + ) + logging.info(f"Compiling model, saving metadata {path_complied}") + model_compiled = jit.compile(deepcopy(model)) + torch.jit.save( + model_compiled, + path_complied, + _extra_files=extra_files, + ) + except Exception as e: # pylint: disable=W0703 + pass else: torch.save(model, Path(args.model_dir) / (args.name + ".model")) + try: + path_complied = Path(args.model_dir) / ( + args.name + "_compiled.model" + ) + logging.info(f"Compiling model, saving metadata to {path_complied}") + model_compiled = jit.compile(deepcopy(model)) + torch.jit.save( + model_compiled, + path_complied, + _extra_files=extra_files, + ) + except Exception as e: # pylint: disable=W070344 + pass if args.distributed: torch.distributed.barrier() diff --git a/mace/data/utils.py b/mace/data/utils.py index 4988859a..a6153665 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -208,21 +208,32 @@ def load_from_xyz( ) energy_key = "REF_energy" for atoms in atoms_list: - atoms.info["REF_energy"] = atoms.get_potential_energy() + try: + atoms.info["REF_energy"] = atoms.get_potential_energy() + except Exception as e: # pylint: disable=W0703 + logging.warning(f"Failed to extract energy: {e}") + atoms.info["REF_energy"] = None if forces_key == "forces": logging.info( "Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to 'REF_forces'" ) forces_key = "REF_forces" for atoms in atoms_list: - atoms.info["REF_forces"] = atoms.get_forces() + try: + atoms.info["REF_forces"] = atoms.get_forces() + except Exception as e: # pylint: disable=W0703 + logging.warning(f"Failed to extract forces: {e}") + atoms.info["REF_forces"] = None if stress_key == "stress": logging.info( "Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to 'REF_stress'" ) stress_key = "REF_stress" for atoms in atoms_list: - atoms.info["REF_stress"] = atoms.get_stress() + try: + atoms.info["REF_stress"] = atoms.get_stress() + except Exception as e: # pylint: disable=W0703 + atoms.info["REF_stress"] = None if not isinstance(atoms_list, list): atoms_list = [atoms_list] diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index ded2ed79..80375590 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -1,7 +1,7 @@ from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState -from .finetuning_utils import extract_load, load_foundations +from .finetuning_utils import load_foundations from .torch_tools import ( TensorDict, cartesian_to_spherical, @@ -66,6 +66,5 @@ "voigt_to_matrix", "init_wandb", "load_foundations", - "extract_load", "build_preprocess_arg_parser", ] diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index fe4c2d63..97cae96a 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -1,70 +1,8 @@ -from typing import Any, Dict - import torch -from e3nn import o3 from mace.tools.utils import AtomicNumberTable -def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: - def radial_to_name(radial_type): - if radial_type == "BesselBasis": - return "bessel" - if radial_type == "GaussianBasis": - return "gaussian" - if radial_type == "ChebychevBasis": - return "chebyshev" - return radial_type - - def radial_to_transform(radial): - if not hasattr(radial, "distance_transform"): - return None - if radial.distance_transform.__class__.__name__ == "AgnesiTransform": - return "Agnesi" - if radial.distance_transform.__class__.__name__ == "SoftTransform": - return "Soft" - return radial.distance_transform.__class__.__name__ - - config = { - "r_max": model.r_max.item(), - "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), - "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), - "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access - "interaction_cls": model.interactions[-1].__class__, - "interaction_cls_first": model.interactions[0].__class__, - "num_interactions": model.num_interactions.item(), - "num_elements": len(model.atomic_numbers), - "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), - "MLP_irreps": o3.Irreps(str(model.readouts[-1].hidden_irreps)), - "gate": model.readouts[-1] # pylint: disable=protected-access - .non_linearity._modules["acts"][0] - .f, - "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), - "avg_num_neighbors": model.interactions[0].avg_num_neighbors, - "atomic_numbers": model.atomic_numbers, - "correlation": len( - model.products[0].symmetric_contractions.contractions[0].weights - ) - + 1, - "radial_type": radial_to_name( - model.radial_embedding.bessel_fn.__class__.__name__ - ), - "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], - "pair_repulsion": hasattr(model, "pair_repulsion_fn"), - "distance_transform": radial_to_transform(model.radial_embedding), - "atomic_inter_scale": model.scale_shift.scale.item(), - "atomic_inter_shift": model.scale_shift.shift.item(), - } - return config - - -def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: - model = torch.load(f=f, map_location=map_location) - model_copy = model.__class__(**extract_config_mace_model(model)) - model_copy.load_state_dict(model.state_dict()) - return model_copy.to(map_location) - - def load_foundations( model: torch.nn.Module, model_foundations: torch.nn.Module, @@ -112,24 +50,24 @@ def load_foundations( for j in range(4): # Assuming 4 layers in conv_tp_weights, layer_name = f"layer{j}" if j == 0: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ) - .weight[:num_radial, :] - .clone() + getattr( + model.interactions[i].conv_tp_weights, layer_name + ).weight = torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, ) + .weight[:num_radial, :] + .clone() ) else: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ).weight.clone() - ) + getattr( + model.interactions[i].conv_tp_weights, layer_name + ).weight = torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() ) model.interactions[i].linear.weight = torch.nn.Parameter( @@ -167,23 +105,23 @@ def load_foundations( for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[j].weights_max = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) + model.products[i].symmetric_contractions.contractions[ + j + ].weights_max = torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights_max[indices_weights, :, :] + .clone() ) for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[k] = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, :] - .clone() - ) + model.products[i].symmetric_contractions.contractions[j].weights[ + k + ] = torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() ) model.products[i].linear.weight = torch.nn.Parameter( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index b1e71fd7..e72bfbb3 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -6,12 +6,15 @@ import ast import dataclasses +import json import logging import os -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple +import numpy as np import torch import torch.distributed +from e3nn import o3 from prettytable import PrettyTable from mace import data, modules @@ -117,6 +120,154 @@ def get_config_type_weights(ct_weights): return config_type_weights +def print_git_commit(): + try: + import git + + repo = git.Repo(search_parent_directories=True) + commit = repo.head.commit.hexsha + logging.info(f"Current Git commit: {commit}") + return commit + except Exception as e: # pylint: disable=W0703 + logging.info(f"Error accessing Git repository: {e}") + return None + + +def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: + if model.__class__.__name__ != "ScaleShiftMACE": + return {"error": "Model is not a ScaleShiftMACE model"} + + def radial_to_name(radial_type): + if radial_type == "BesselBasis": + return "bessel" + if radial_type == "GaussianBasis": + return "gaussian" + if radial_type == "ChebychevBasis": + return "chebyshev" + return radial_type + + def radial_to_transform(radial): + if not hasattr(radial, "distance_transform"): + return None + if radial.distance_transform.__class__.__name__ == "AgnesiTransform": + return "Agnesi" + if radial.distance_transform.__class__.__name__ == "SoftTransform": + return "Soft" + return radial.distance_transform.__class__.__name__ + + config = { + "r_max": model.r_max.item(), + "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), + "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), + "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access + "interaction_cls": model.interactions[-1].__class__, + "interaction_cls_first": model.interactions[0].__class__, + "num_interactions": model.num_interactions.item(), + "num_elements": len(model.atomic_numbers), + "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), + "MLP_irreps": o3.Irreps(str(model.readouts[-1].hidden_irreps)), + "gate": model.readouts[-1] # pylint: disable=protected-access + .non_linearity._modules["acts"][0] + .f, + "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), + "avg_num_neighbors": model.interactions[0].avg_num_neighbors, + "atomic_numbers": model.atomic_numbers, + "correlation": len( + model.products[0].symmetric_contractions.contractions[0].weights + ) + + 1, + "radial_type": radial_to_name( + model.radial_embedding.bessel_fn.__class__.__name__ + ), + "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], + "pair_repulsion": hasattr(model, "pair_repulsion_fn"), + "distance_transform": radial_to_transform(model.radial_embedding), + "atomic_inter_scale": model.scale_shift.scale.item(), + "atomic_inter_shift": model.scale_shift.shift.item(), + } + return config + + +def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: + model = torch.load(f=f, map_location=map_location) + model_copy = model.__class__(**extract_config_mace_model(model)) + model_copy.load_state_dict(model.state_dict()) + return model_copy.to(map_location) + + +def convert_to_json_format(dict_input): + for key, value in dict_input.items(): + if isinstance(value, (np.ndarray, torch.Tensor)): + dict_input[key] = value.tolist() + # # check if the value is a class and convert it to a string + elif hasattr(value, "__class__"): + dict_input[key] = str(value) + return dict_input + + +def convert_from_json_format(dict_input): + dict_output = dict_input.copy() + if ( + dict_input["interaction_cls"] + == "" + ): + dict_output[ + "interaction_cls" + ] = modules.blocks.RealAgnosticResidualInteractionBlock + if ( + dict_input["interaction_cls"] + == "" + ): + dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock + if ( + dict_input["interaction_cls_first"] + == "" + ): + dict_output[ + "interaction_cls_first" + ] = modules.blocks.RealAgnosticResidualInteractionBlock + if ( + dict_input["interaction_cls_first"] + == "" + ): + dict_output[ + "interaction_cls_first" + ] = modules.blocks.RealAgnosticInteractionBlock + dict_output["r_max"] = float(dict_input["r_max"]) + dict_output["num_bessel"] = int(dict_input["num_bessel"]) + dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"]) + dict_output["max_ell"] = int(dict_input["max_ell"]) + dict_output["num_interactions"] = int(dict_input["num_interactions"]) + dict_output["num_elements"] = int(dict_input["num_elements"]) + dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"]) + dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"]) + dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"]) + dict_output["gate"] = torch.nn.functional.silu + dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"]) + dict_output["atomic_numbers"] = dict_input["atomic_numbers"] + dict_output["correlation"] = int(dict_input["correlation"]) + dict_output["radial_type"] = dict_input["radial_type"] + dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"]) + dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"]) + dict_output["distance_transform"] = dict_input["distance_transform"] + dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"]) + dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"]) + + return dict_output + + +def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module: + extra_files_extract = {"commit.txt": None, "config.json": None} + model_jit_load = torch.jit.load( + f, _extra_files=extra_files_extract, map_location=map_location + ) + model_load_yaml = modules.ScaleShiftMACE( + **convert_from_json_format(json.loads(extra_files_extract["config.json"])) + ) + model_load_yaml.load_state_dict(model_jit_load.state_dict()) + return model_load_yaml.to(map_location) + + def get_atomic_energies(E0s, train_collection, z_table) -> dict: if E0s is not None: logging.info( diff --git a/tests/test_foundations.py b/tests/test_foundations.py index cf52dd99..69963c67 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -9,7 +9,8 @@ from mace import data, modules, tools from mace.calculators import mace_mp, mace_off from mace.tools import torch_geometric -from mace.tools.finetuning_utils import extract_config_mace_model, load_foundations +from mace.tools.finetuning_utils import load_foundations +from mace.tools.scripts_utils import extract_config_mace_model from mace.tools.utils import AtomicNumberTable torch.set_default_dtype(torch.float64)