Skip to content

Commit

Permalink
Merge pull request #548 from ACEsuit/develop
Browse files Browse the repository at this point in the history
Fix MAE table for universal loss
  • Loading branch information
ilyes319 authored Aug 12, 2024
2 parents 126f490 + 81c02b4 commit 575af01
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 15 deletions.
9 changes: 7 additions & 2 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,10 @@ def run(args: argparse.Namespace) -> None:
if args.loss in ("stress", "virials", "huber", "universal"):
compute_virials = True
args.compute_stress = True
args.error_table = "PerAtomRMSEstressvirials"
if "MAE" in args.error_table:
args.error_table = "PerAtomMAEstressvirials"
else:
args.error_table = "PerAtomRMSEstressvirials"

output_args = {
"energy": compute_energy,
Expand Down Expand Up @@ -821,7 +824,9 @@ def run(args: argparse.Namespace) -> None:
),
}
if swa_eval:
torch.save(model, Path(args.model_dir) / (args.name + "_stagetwo.model"))
torch.save(
model, Path(args.model_dir) / (args.name + "_stagetwo.model")
)
try:
path_complied = Path(args.model_dir) / (
args.name + "_stagetwo_compiled.model"
Expand Down
29 changes: 21 additions & 8 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"PerAtomRMSE",
"TotalRMSE",
"PerAtomRMSEstressvirials",
"PerAtomMAEstressvirials",
"PerAtomMAE",
"TotalMAE",
"DipoleRMSE",
Expand Down Expand Up @@ -388,7 +389,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--forces_weight", help="weight of forces loss", type=float, default=100.0
)
parser.add_argument(
"--swa_forces_weight","--stage_two_forces_weight",
"--swa_forces_weight",
"--stage_two_forces_weight",
help="weight of forces loss after starting Stage Two (previously called swa)",
type=float,
default=100.0,
Expand All @@ -398,7 +400,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--energy_weight", help="weight of energy loss", type=float, default=1.0
)
parser.add_argument(
"--swa_energy_weight","--stage_two_energy_weight",
"--swa_energy_weight",
"--stage_two_energy_weight",
help="weight of energy loss after starting Stage Two (previously called swa)",
type=float,
default=1000.0,
Expand All @@ -408,7 +411,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--virials_weight", help="weight of virials loss", type=float, default=1.0
)
parser.add_argument(
"--swa_virials_weight", "--stage_two_virials_weight",
"--swa_virials_weight",
"--stage_two_virials_weight",
help="weight of virials loss after starting Stage Two (previously called swa)",
type=float,
default=10.0,
Expand All @@ -418,7 +422,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--stress_weight", help="weight of virials loss", type=float, default=1.0
)
parser.add_argument(
"--swa_stress_weight", "--stage_two_stress_weight",
"--swa_stress_weight",
"--stage_two_stress_weight",
help="weight of stress loss after starting Stage Two (previously called swa)",
type=float,
default=10.0,
Expand All @@ -428,7 +433,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--dipole_weight", help="weight of dipoles loss", type=float, default=1.0
)
parser.add_argument(
"--swa_dipole_weight","--stage_two_dipole_weight",
"--swa_dipole_weight",
"--stage_two_dipole_weight",
help="weight of dipoles after starting Stage Two (previously called swa)",
type=float,
default=1.0,
Expand Down Expand Up @@ -467,7 +473,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--lr", help="Learning rate of optimizer", type=float, default=0.01
)
parser.add_argument(
"--swa_lr", "--stage_two_lr", help="Learning rate of optimizer in Stage Two (previously called swa)", type=float, default=1e-3, dest="swa_lr"
"--swa_lr",
"--stage_two_lr",
help="Learning rate of optimizer in Stage Two (previously called swa)",
type=float,
default=1e-3,
dest="swa_lr",
)
parser.add_argument(
"--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7
Expand All @@ -494,14 +505,16 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
default=0.9993,
)
parser.add_argument(
"--swa", "--stage_two",
"--swa",
"--stage_two",
help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them",
action="store_true",
default=False,
dest="swa",
)
parser.add_argument(
"--start_swa","--start_stage_two",
"--start_swa",
"--start_stage_two",
help="Number of epochs before changing to Stage Two loss weights",
type=int,
default=None,
Expand Down
51 changes: 46 additions & 5 deletions mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,18 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict:
f"Could not compute average E0s if no training xyz given, error {e} occured"
) from e
else:
try:
atomic_energies_dict = ast.literal_eval(E0s)
assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(f"E0s specified invalidly, error {e} occured") from e
if E0s.endswith(".json"):
logging.info(f"Loading atomic energies from {E0s}")
with open(E0s, "r", encoding="utf-8") as f:
atomic_energies_dict = json.load(f)
else:
try:
atomic_energies_dict = ast.literal_eval(E0s)
assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(
f"E0s specified invalidly, error {e} occured"
) from e
else:
raise RuntimeError(
"E0s not found in training file and not specified in command line"
Expand Down Expand Up @@ -454,6 +461,14 @@ def create_error_table(
"relative F RMSE %",
"RMSE Stress (Virials) / meV / A (A^3)",
]
elif table_type == "PerAtomMAEstressvirials":
table.field_names = [
"config_type",
"MAE E / meV / atom",
"MAE F / meV / A",
"relative F MAE %",
"MAE Stress (Virials) / meV / A (A^3)",
]
elif table_type == "TotalMAE":
table.field_names = [
"config_type",
Expand Down Expand Up @@ -558,6 +573,32 @@ def create_error_table(
f"{metrics['rmse_virials'] * 1000:.1f}",
]
)
elif (
table_type == "PerAtomMAEstressvirials"
and metrics["mae_stress"] is not None
):
table.add_row(
[
name,
f"{metrics['mae_e_per_atom'] * 1000:.1f}",
f"{metrics['mae_f'] * 1000:.1f}",
f"{metrics['rel_mae_f']:.2f}",
f"{metrics['mae_stress'] * 1000:.1f}",
]
)
elif (
table_type == "PerAtomMAEstressvirials"
and metrics["mae_virials"] is not None
):
table.add_row(
[
name,
f"{metrics['mae_e_per_atom'] * 1000:.1f}",
f"{metrics['mae_f'] * 1000:.1f}",
f"{metrics['rel_mae_f']:.2f}",
f"{metrics['mae_virials'] * 1000:.1f}",
]
)
elif table_type == "TotalMAE":
table.add_row(
[
Expand Down
20 changes: 20 additions & 0 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,26 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None):
logging.info(
f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV"
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_stress_per_atom"] is not None
):
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_stress = eval_metrics["mae_stress"] * 1e3
logging.info(
f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3"
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_virials_per_atom"] is not None
):
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_virials = eval_metrics["mae_virials"] * 1e3
logging.info(
f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_virials={error_virials:.1f} meV"
)
elif log_errors == "TotalRMSE":
error_e = eval_metrics["rmse_e"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
Expand Down

0 comments on commit 575af01

Please sign in to comment.