From 7255360810627272ad1a6d8f7f505d8d834cb091 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 20 Aug 2024 17:29:09 -0400 Subject: [PATCH] lint merged changes --- mace/cli/run_train.py | 4 +--- mace/modules/models.py | 4 +++- mace/tools/scripts_utils.py | 31 +++++++++++-------------------- mace/tools/train.py | 10 ++-------- 4 files changed, 17 insertions(+), 32 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 76731f0a..0bc01817 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -764,9 +764,7 @@ def run(args: argparse.Namespace) -> None: "config.yaml": json.dumps(convert_to_json_format(extract_config_mace_model(model))), } if swa_eval: - torch.save( - model, Path(args.model_dir) / (args.name + "_stagetwo.model") - ) + torch.save(model, Path(args.model_dir) / (args.name + "_stagetwo.model")) try: path_complied = Path(args.model_dir) / (args.name + "_stagetwo_compiled.model") logging.info(f"Compiling model, saving metadata {path_complied}") diff --git a/mace/modules/models.py b/mace/modules/models.py index 50e9c5fa..c429eb43 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -126,7 +126,9 @@ def __init__( self.readouts.append(LinearReadoutBlock(hidden_irreps)) for i in range(num_interactions - 1): - hidden_irreps_out = str(hidden_irreps[0]) if i == num_interactions - 2 else hidden_irreps # Select only scalars for last layer + hidden_irreps_out = ( + str(hidden_irreps[0]) if i == num_interactions - 2 else hidden_irreps + ) # Select only scalars for last layer inter = interaction_cls( node_attrs_irreps=node_attr_irreps, node_feats_irreps=hidden_irreps, diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index d232d7e7..cb335af2 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -246,19 +246,16 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: atomic_energies_dict = data.compute_average_E0s(train_collection, z_table) except Exception as e: raise RuntimeError(f"Could not compute average E0s if no training xyz given, error {e} occured") from e + elif E0s.endswith(".json"): + logging.info(f"Loading atomic energies from {E0s}") + with open(E0s, encoding="utf-8") as f: + atomic_energies_dict = json.load(f) else: - if E0s.endswith(".json"): - logging.info(f"Loading atomic energies from {E0s}") - with open(E0s, "r", encoding="utf-8") as f: - atomic_energies_dict = json.load(f) - else: - try: - atomic_energies_dict = ast.literal_eval(E0s) - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError( - f"E0s specified invalidly, error {e} occured" - ) from e + try: + atomic_energies_dict = ast.literal_eval(E0s) + assert isinstance(atomic_energies_dict, dict) + except Exception as e: + raise RuntimeError(f"E0s specified invalidly, error {e} occured") from e else: raise RuntimeError("E0s not found in training file and not specified in command line") return atomic_energies_dict @@ -495,10 +492,7 @@ def create_error_table( f"{metrics['rmse_virials'] * 1000:.1f}", ] ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_stress"] is not None - ): + elif table_type == "PerAtomMAEstressvirials" and metrics["mae_stress"] is not None: table.add_row( [ name, @@ -508,10 +502,7 @@ def create_error_table( f"{metrics['mae_stress'] * 1000:.1f}", ] ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_virials"] is not None - ): + elif table_type == "PerAtomMAEstressvirials" and metrics["mae_virials"] is not None: table.add_row( [ name, diff --git a/mace/tools/train.py b/mace/tools/train.py index e487f52d..6ac61bef 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -62,20 +62,14 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): logging.info( f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" ) - elif ( - log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_stress_per_atom"] is not None - ): + elif log_errors == "PerAtomMAEstressvirials" and eval_metrics["mae_stress_per_atom"] is not None: error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 error_stress = eval_metrics["mae_stress"] * 1e3 logging.info( f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A, MAE_stress={error_stress:.1f} meV / A^3" ) - elif ( - log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_virials_per_atom"] is not None - ): + elif log_errors == "PerAtomMAEstressvirials" and eval_metrics["mae_virials_per_atom"] is not None: error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 error_virials = eval_metrics["mae_virials"] * 1e3