Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/ACEsuit/mace into ruff-linter
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Aug 20, 2024
2 parents 86dcd6d + 575af01 commit 2b981aa
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 7 deletions.
9 changes: 7 additions & 2 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ def run(args: argparse.Namespace) -> None:
if args.loss in ("stress", "virials", "huber", "universal"):
compute_virials = True
args.compute_stress = True
args.error_table = "PerAtomRMSEstressvirials"
if "MAE" in args.error_table:
args.error_table = "PerAtomMAEstressvirials"
else:
args.error_table = "PerAtomRMSEstressvirials"

output_args = {
"energy": compute_energy,
Expand Down Expand Up @@ -761,7 +764,9 @@ def run(args: argparse.Namespace) -> None:
"config.yaml": json.dumps(convert_to_json_format(extract_config_mace_model(model))),
}
if swa_eval:
torch.save(model, Path(args.model_dir) / (args.name + "_stagetwo.model"))
torch.save(
model, Path(args.model_dir) / (args.name + "_stagetwo.model")
)
try:
path_complied = Path(args.model_dir) / (args.name + "_stagetwo_compiled.model")
logging.info(f"Compiling model, saving metadata {path_complied}")
Expand Down
1 change: 1 addition & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"PerAtomRMSE",
"TotalRMSE",
"PerAtomRMSEstressvirials",
"PerAtomMAEstressvirials",
"PerAtomMAE",
"TotalMAE",
"DipoleRMSE",
Expand Down
51 changes: 46 additions & 5 deletions mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,18 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict:
except Exception as e:
raise RuntimeError(f"Could not compute average E0s if no training xyz given, error {e} occured") from e
else:
try:
atomic_energies_dict = ast.literal_eval(E0s)
assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(f"E0s specified invalidly, error {e} occured") from e
if E0s.endswith(".json"):
logging.info(f"Loading atomic energies from {E0s}")
with open(E0s, "r", encoding="utf-8") as f:
atomic_energies_dict = json.load(f)
else:
try:
atomic_energies_dict = ast.literal_eval(E0s)
assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(
f"E0s specified invalidly, error {e} occured"
) from e
else:
raise RuntimeError("E0s not found in training file and not specified in command line")
return atomic_energies_dict
Expand Down Expand Up @@ -384,6 +391,14 @@ def create_error_table(
"relative F RMSE %",
"RMSE Stress (Virials) / meV / A (A^3)",
]
elif table_type == "PerAtomMAEstressvirials":
table.field_names = [
"config_type",
"MAE E / meV / atom",
"MAE F / meV / A",
"relative F MAE %",
"MAE Stress (Virials) / meV / A (A^3)",
]
elif table_type == "TotalMAE":
table.field_names = [
"config_type",
Expand Down Expand Up @@ -480,6 +495,32 @@ def create_error_table(
f"{metrics['rmse_virials'] * 1000:.1f}",
]
)
elif (
table_type == "PerAtomMAEstressvirials"
and metrics["mae_stress"] is not None
):
table.add_row(
[
name,
f"{metrics['mae_e_per_atom'] * 1000:.1f}",
f"{metrics['mae_f'] * 1000:.1f}",
f"{metrics['rel_mae_f']:.2f}",
f"{metrics['mae_stress'] * 1000:.1f}",
]
)
elif (
table_type == "PerAtomMAEstressvirials"
and metrics["mae_virials"] is not None
):
table.add_row(
[
name,
f"{metrics['mae_e_per_atom'] * 1000:.1f}",
f"{metrics['mae_f'] * 1000:.1f}",
f"{metrics['rel_mae_f']:.2f}",
f"{metrics['mae_virials'] * 1000:.1f}",
]
)
elif table_type == "TotalMAE":
table.add_row(
[
Expand Down
20 changes: 20 additions & 0 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None):
logging.info(
f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV"
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_stress_per_atom"] is not None
):
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_stress = eval_metrics["mae_stress"] * 1e3
logging.info(
f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3"
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_virials_per_atom"] is not None
):
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_virials = eval_metrics["mae_virials"] * 1e3
logging.info(
f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_virials={error_virials:.1f} meV"
)
elif log_errors == "TotalRMSE":
error_e = eval_metrics["rmse_e"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
Expand Down

0 comments on commit 2b981aa

Please sign in to comment.