Skip to content

Commit

Permalink
lint merged changes
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Aug 20, 2024
1 parent 2b981aa commit 7255360
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 32 deletions.
4 changes: 1 addition & 3 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,7 @@ def run(args: argparse.Namespace) -> None:
"config.yaml": json.dumps(convert_to_json_format(extract_config_mace_model(model))),
}
if swa_eval:
torch.save(
model, Path(args.model_dir) / (args.name + "_stagetwo.model")
)
torch.save(model, Path(args.model_dir) / (args.name + "_stagetwo.model"))
try:
path_complied = Path(args.model_dir) / (args.name + "_stagetwo_compiled.model")
logging.info(f"Compiling model, saving metadata {path_complied}")
Expand Down
4 changes: 3 additions & 1 deletion mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def __init__(
self.readouts.append(LinearReadoutBlock(hidden_irreps))

for i in range(num_interactions - 1):
hidden_irreps_out = str(hidden_irreps[0]) if i == num_interactions - 2 else hidden_irreps # Select only scalars for last layer
hidden_irreps_out = (
str(hidden_irreps[0]) if i == num_interactions - 2 else hidden_irreps
) # Select only scalars for last layer
inter = interaction_cls(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
Expand Down
31 changes: 11 additions & 20 deletions mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,19 +246,16 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict:
atomic_energies_dict = data.compute_average_E0s(train_collection, z_table)
except Exception as e:
raise RuntimeError(f"Could not compute average E0s if no training xyz given, error {e} occured") from e
elif E0s.endswith(".json"):
logging.info(f"Loading atomic energies from {E0s}")
with open(E0s, encoding="utf-8") as f:
atomic_energies_dict = json.load(f)
else:
if E0s.endswith(".json"):
logging.info(f"Loading atomic energies from {E0s}")
with open(E0s, "r", encoding="utf-8") as f:
atomic_energies_dict = json.load(f)
else:
try:
atomic_energies_dict = ast.literal_eval(E0s)
assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(
f"E0s specified invalidly, error {e} occured"
) from e
try:
atomic_energies_dict = ast.literal_eval(E0s)
assert isinstance(atomic_energies_dict, dict)
except Exception as e:
raise RuntimeError(f"E0s specified invalidly, error {e} occured") from e
else:
raise RuntimeError("E0s not found in training file and not specified in command line")
return atomic_energies_dict
Expand Down Expand Up @@ -495,10 +492,7 @@ def create_error_table(
f"{metrics['rmse_virials'] * 1000:.1f}",
]
)
elif (
table_type == "PerAtomMAEstressvirials"
and metrics["mae_stress"] is not None
):
elif table_type == "PerAtomMAEstressvirials" and metrics["mae_stress"] is not None:
table.add_row(
[
name,
Expand All @@ -508,10 +502,7 @@ def create_error_table(
f"{metrics['mae_stress'] * 1000:.1f}",
]
)
elif (
table_type == "PerAtomMAEstressvirials"
and metrics["mae_virials"] is not None
):
elif table_type == "PerAtomMAEstressvirials" and metrics["mae_virials"] is not None:
table.add_row(
[
name,
Expand Down
10 changes: 2 additions & 8 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,14 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None):
logging.info(
f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV"
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_stress_per_atom"] is not None
):
elif log_errors == "PerAtomMAEstressvirials" and eval_metrics["mae_stress_per_atom"] is not None:
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_stress = eval_metrics["mae_stress"] * 1e3
logging.info(
f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3"
)
elif (
log_errors == "PerAtomMAEstressvirials"
and eval_metrics["mae_virials_per_atom"] is not None
):
elif log_errors == "PerAtomMAEstressvirials" and eval_metrics["mae_virials_per_atom"] is not None:
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
error_virials = eval_metrics["mae_virials"] * 1e3
Expand Down

0 comments on commit 7255360

Please sign in to comment.