diff --git a/CHANGELOG.md b/CHANGELOG.md index 556e70ab..25c2d502 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ Most recent change on the bottom. - alternate neighborlist support enabled with `NEQUIP_NL` environment variable, which can be set to `ase` (default), `matscipy` or `vesin` - Allow `n_train` and `n_val` to be specified as percentages of datasets. - Only attempt training restart if `trainer.pth` file present (prevents unnecessary crashes due to file-not-found errors in some cases) +- Stratified metrics now possible; stratified by reference values in percent or raw units, or by error population. ### Changed - [Breaking] `NEQUIP_MATSCIPY_NL` environment variable no longer supported diff --git a/CITATION.cff b/CITATION.cff index 3f107357..bdbe3022 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -2,29 +2,50 @@ cff-version: "1.2.0" message: "If you use this software, please cite our article." authors: - family-names: Batzner - given-names: Simon + given-names: Simon - family-names: Musaelian - given-names: Albert + given-names: Albert - family-names: Sun - given-names: Lixin + given-names: Lixin - family-names: Geiger - given-names: Mario + given-names: Mario - family-names: Mailoa - given-names: Jonathan P. + given-names: Jonathan P. - family-names: Kornbluth - given-names: Mordechai + given-names: Mordechai - family-names: Molinari - given-names: Nicola + given-names: Nicola - family-names: Smidt - given-names: Tess E. + given-names: Tess E. - family-names: Kozinsky - given-names: Boris + given-names: Boris doi: 10.1038/s41467-022-29939-5 -date-published: 2022-05-04 -issn: 2041-1723 -journal: Nature Communications -start: 2453 -title: "E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials" -type: article -url: "https://www.nature.com/articles/s41467-022-29939-5" -volume: 13 +preferred-citation: + authors: + - family-names: Batzner + given-names: Simon + - family-names: Musaelian + given-names: Albert + - family-names: Sun + given-names: Lixin + - family-names: Geiger + given-names: Mario + - family-names: Mailoa + given-names: Jonathan P. + - family-names: Kornbluth + given-names: Mordechai + - family-names: Molinari + given-names: Nicola + - family-names: Smidt + given-names: Tess E. + - family-names: Kozinsky + given-names: Boris + doi: 10.1038/s41467-022-29939-5 + date-published: 2022-05-04 + issn: 2041-1723 + journal: Nature Communications + start: 2453 + title: "E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials" + type: article + url: "https://www.nature.com/articles/s41467-022-29939-5" + volume: 13 diff --git a/configs/full.yaml b/configs/full.yaml index 91684a36..371b0998 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -281,12 +281,23 @@ metrics_components: - - forces - rmse - PerSpecies: True - report_per_component: False + report_per_component: False - - total_energy - mae - - total_energy - mae - PerAtom: True # if true, energy is normalized by the number of atoms +# we can also output errors stratified by the reference value ranges (in percent or absolute values), or by the error populations in percent: + - - total_energy + - mae + - stratify: 10%_range # stratify by range (in reference energies per atom), in increments of 10% (i.e. errors for first 10% lowest reference values, next 10% etc) + PerAtom: True + - - forces + - rmse + - stratify: 10%_population # stratify by population (in forces errors per atom), in increments of 10% (i.e. errors for first 10% lowest errors, next 10% etc) + - - stress + - mae + - stratify: 0.001 # stratify by absolute value (in reference stresses), in increments of 0.001 # optimizer, may be any optimizer defined in torch.optim # the name `optimizer_name`is case sensitive diff --git a/nequip/__init__.py b/nequip/__init__.py index ce145b41..fdfce931 100644 --- a/nequip/__init__.py +++ b/nequip/__init__.py @@ -1,3 +1,4 @@ +import os import sys from ._version import __version__ # noqa: F401 @@ -16,7 +17,10 @@ ), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found" # warn if using 1.13* or 2.0.* -if packaging.version.parse("1.13.0") <= torch_version: +if ( + packaging.version.parse("1.13.0") <= torch_version + and int(os.environ.get("PYTORCH_VERSION_WARNING", 1)) != 0 +): warnings.warn( f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.* have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue." ) diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py index b40c3a8a..8054ab5b 100644 --- a/nequip/scripts/evaluate.py +++ b/nequip/scripts/evaluate.py @@ -455,11 +455,13 @@ def main(args=None, running_as_script: bool = True): if do_metrics: logger.info("\n--- Final result: ---") - logger.critical( + logger.info( "\n".join( - f"{k:>20s} = {v:< 20f}" + f"{k:>30s} = {v:< 30f}" for k, v in metrics.flatten_metrics( - metrics.current_result(), + metrics.current_result( + verbose=True + ), # verbose output about strata on final call type_names=dataset.type_mapper.type_names, )[0].items() ) diff --git a/nequip/train/metrics.py b/nequip/train/metrics.py index 103434d5..9bac8c81 100644 --- a/nequip/train/metrics.py +++ b/nequip/train/metrics.py @@ -5,6 +5,7 @@ import yaml import torch +import numpy as np from nequip.data import AtomicDataDict from torch_runstats import RunningStats, Reduction @@ -64,10 +65,17 @@ def __init__( self.params = {} self.funcs = {} self.kwargs = {} + self.stratified_stats = ( + {} + ) # need to be stored separately, as needs all data labels at once + for component in components: key, reduction, params = Metrics.parse(component) + params["stratify"] = params.get( + "stratify", False + ) # can be either 'XX%_range', 'XX%_population' or int/float (raw unit for separation) params["PerSpecies"] = params.get("PerSpecies", False) params["PerAtom"] = params.get("PerAtom", False) @@ -75,17 +83,21 @@ def __init__( functional = params.get("functional", "L1Loss") - # default is to flatten the array - - if key not in self.running_stats: - self.running_stats[key] = {} + if key not in self.kwargs: self.funcs[key] = {} self.kwargs[key] = {} self.params[key] = {} + if key not in self.running_stats and not params.get("stratify", False): + self.running_stats[key] = {} # default is to flatten the array + + if key not in self.stratified_stats and params.get("stratify", False): + self.stratified_stats[key] = {} + # store for initialization kwargs = deepcopy(params) kwargs.pop("functional", "L1Loss") + kwargs.pop("stratify") kwargs.pop("PerSpecies") kwargs.pop("PerAtom") @@ -160,20 +172,26 @@ def __call__(self, pred: dict, ref: dict): ) _, params = self.params[key][param_hash] + stratify = params["stratify"] per_species = params["PerSpecies"] per_atom = params["PerAtom"] - # initialize the internal run_stat base on the error shape - if param_hash not in self.running_stats[key]: - self.running_stats[key][param_hash] = self.init_runstat( - params=kwargs, error=error - ) + if not stratify: + # initialize the internal run_stat base on the error shape + if param_hash not in self.running_stats[key]: + self.running_stats[key][param_hash] = self.init_runstat( + params=kwargs, error=error + ) - stat = self.running_stats[key][param_hash] + stat = self.running_stats[key][param_hash] params = {} if per_species: - # TO DO, this needs OneHot component. will need to be decoupled + if stratify: + raise NotImplementedError( + "Stratify is not implemented for per_species" + ) + # TODO, this needs OneHot component. will need to be decoupled params = { "accumulate_by": pred[AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1) } @@ -184,10 +202,28 @@ def __call__(self, pred: dict, ref: dict): else: error_N = error - if stat.dim == () and not per_species: - metrics[(key, param_hash)] = stat.accumulate_batch( - error_N.flatten(), **params - ) + if not stratify and stat.dim == () and not per_species: + error_N = error_N.flatten() + + if ( # just need error and ref value, note that forces are not stratified by xyz + stratify # norm (just raw x, y, z values used) + ): + if param_hash not in self.stratified_stats[key]: + self.stratified_stats[key][param_hash] = { + "error": error_N, + "ref_val": ref[key], + } + else: + self.stratified_stats[key][param_hash]["error"] = torch.cat( + (self.stratified_stats[key][param_hash]["error"], error_N) + ) + self.stratified_stats[key][param_hash]["ref_val"] = torch.cat( + ( + self.stratified_stats[key][param_hash]["ref_val"], + ref[key], + ) + ) + else: metrics[(key, param_hash)] = stat.accumulate_batch( error_N, **params @@ -205,12 +241,159 @@ def to(self, device): for stat in stats.values(): stat.to(device=device) - def current_result(self): + def current_result(self, verbose=False): + """ + Return the current result of the metrics. + + Args: + verbose (bool): + If True, prints information about stratified metrics (i.e. ranges). + Default: False + """ metrics = {} for key, stats in self.running_stats.items(): - for reduction, stat in stats.items(): - metrics[(key, reduction)] = stat.current_result() + for param_hash, stat in stats.items(): + metrics[(key, param_hash)] = stat.current_result() + + for key, stats in self.stratified_stats.items(): + for ( + param_hash, + stratified_stat_dict, + ) in stats.items(): # compute the stratified error: + reduction, params = self.params[key][param_hash] + # flatten in case has dim > 1 (force, stress): + errors = stratified_stat_dict["error"].flatten() + ref_vals = stratified_stat_dict["ref_val"].flatten() + stratified_metric_dict = {} + + if ( + isinstance(params.get("stratify"), str) + and "range" in params.get("stratify") + ) or isinstance( + params.get("stratify"), (int, float) + ): # stratify by range, + # either as percent string or raw unit: + min_max_range = (ref_vals.max() - ref_vals.min()).cpu().numpy() + if isinstance(params.get("stratify"), (int, float)): + range_separation = range_separation_str = params["stratify"] + else: + range_separation = ( + float(params["stratify"].strip("%_range")) / 100 + ) * min_max_range + range_separation_str = ( + f"{params['stratify'].strip('_range')} (= " + f"~{range_separation:.3f})" + ) + if verbose: + print( + f"Stratifying {key} errors by {key} range, in increments of " + f"{range_separation_str}, with min-max dataset range of {min_max_range:.3f}" + ) + + num_strata = np.ceil(min_max_range / range_separation).astype(int) + if isinstance(params.get("stratify"), str): + format = ( # .1% if 1/num_strata is not an integer, otherwise .0% (no decimal) + ".1%" + if not np.isclose( + (1 / num_strata) * 100, round((1 / num_strata) * 100) + ) + else ".0%" + ) + elif isinstance( + params.get("stratify"), float + ): # same decimal places as given + format = f".{str(params.get('stratify'))[::-1].find('.')}f" + else: + format = "" + + for i in range(num_strata): + if isinstance(params.get("stratify"), str): + stratum_key = ( + f"{i / num_strata:{format}}" + f"-{(i + 1) / num_strata:{format}}_range" + ) + else: + stratum_key = ( + f"{i * range_separation:{format}}" + f"-{(i + 1) * range_separation:{format}}_range" + ) + + mask = (ref_vals >= (i * range_separation) + ref_vals.min()) & ( + ref_vals < ((i + 1) * range_separation) + ref_vals.min() + ) + masked_errors = errors[mask] + if len(masked_errors) > 0: + if reduction in ["rms", "rmse"]: + stat = masked_errors.square().mean().sqrt() + elif reduction in ["mean", "mae"]: + stat = masked_errors.mean() + else: + raise NotImplementedError( + f"reduction {reduction} not implemented" + ) + + stratified_metric_dict[stratum_key] = stat + + else: + stratified_metric_dict[stratum_key] = torch.tensor( + float("nan") + ) + + elif isinstance( + params.get("stratify"), str + ) and "population" in params.get( + "stratify" + ): # stratify by population (given as percent string) + total_population = len(errors) + population_separation = ( + float(params["stratify"].strip("%_population")) / 100 + ) * total_population + if verbose: + print( + f"Stratifying {key} errors by population, in increments of " + f"{params['stratify'].strip('_population')} (= " + f"~{round(population_separation)} labels)" + ) + + num_strata = np.ceil( + total_population / population_separation + ).astype(int) + format = ( # .1% if 1/num_strata is not an integer, otherwise .0% (no decimal) + ".1%" + if not np.isclose( + (1 / num_strata) * 100, round((1 / num_strata) * 100) + ) + else ".0%" + ) + sorted_errors = torch.sort(errors, dim=0).values + + for i in range(num_strata): + stratum_key = f"{i/num_strata:{format}}-{(i+1)/num_strata:{format}}_population" + stratum_errors = sorted_errors[ + int(population_separation * i) : int( + population_separation * (i + 1) + ) + ] + if len(stratum_errors) > 0: + if reduction in ["rms", "rmse"]: + stat = stratum_errors.square().mean().sqrt() + elif reduction in ["mean", "mae"]: + stat = stratum_errors.mean() + else: + raise NotImplementedError( + f"reduction {reduction} not implemented" + ) + + stratified_metric_dict[stratum_key] = stat + + else: + stratified_metric_dict[stratum_key] = torch.tensor( + float("nan") + ) + + metrics[(key, param_hash)] = stratified_metric_dict + return metrics def flatten_metrics(self, metrics, type_names=None): @@ -230,34 +413,45 @@ def flatten_metrics(self, metrics, type_names=None): suffix = "/N" if per_atom else "" item_name = f"{short_name}{suffix}_{reduction}" - stat = self.running_stats[key][param_hash] per_species = params["PerSpecies"] + stratify = params["stratify"] - if per_species: - if stat.output_dim == tuple(): - if type_names is None: - type_names = [i for i in range(len(value))] - for id_ele, v in enumerate(value): - if type_names is not None: - flat_dict[f"{type_names[id_ele]}_{item_name}"] = v.item() - else: - flat_dict[f"{id_ele}_{item_name}"] = v.item() - - flat_dict[f"psavg_{item_name}"] = value.mean().item() - else: - for id_ele, vec in enumerate(value): - ele = type_names[id_ele] - for idx, v in enumerate(vec): - name = f"{ele}_{item_name}_{idx}" - flat_dict[name] = v.item() - skip_keys.append(name) + if stratify: # then value is a dict of {stratum_idx: value} + for stratum_idx, v in value.items(): + name = f"{stratum_idx}_{item_name}" + flat_dict[name] = v.item() + skip_keys.append(name) else: - if stat.output_dim == tuple(): - # a scalar - flat_dict[item_name] = value.item() + stat = self.running_stats[key][param_hash] + + if per_species: + stat = self.running_stats[key][param_hash] + if stat.output_dim == tuple(): + if type_names is None: + type_names = [i for i in range(len(value))] + for id_ele, v in enumerate(value): + if type_names is not None: + flat_dict[f"{type_names[id_ele]}_{item_name}"] = ( + v.item() + ) + else: + flat_dict[f"{id_ele}_{item_name}"] = v.item() + + flat_dict[f"psavg_{item_name}"] = value.mean().item() + else: + for id_ele, vec in enumerate(value): + ele = type_names[id_ele] + for idx, v in enumerate(vec): + name = f"{ele}_{item_name}_{idx}" + flat_dict[name] = v.item() + skip_keys.append(name) + else: - # a vector - for idx, v in enumerate(value.flatten()): - flat_dict[f"{item_name}_{idx}"] = v.item() + if stat.output_dim == tuple(): # a scalar + flat_dict[item_name] = value.item() + else: # a vector + for idx, v in enumerate(value.flatten()): + flat_dict[f"{item_name}_{idx}"] = v.item() + return flat_dict, skip_keys diff --git a/tests/integration/test_evaluate.py b/tests/integration/test_evaluate.py index 4a1388ed..71bfa575 100644 --- a/tests/integration/test_evaluate.py +++ b/tests/integration/test_evaluate.py @@ -57,7 +57,8 @@ def runit(params: dict): metrics = dict( [ tuple(e.strip() for e in line.split("=", 1)) - for line in retcode.stdout.decode().splitlines() + for line in retcode.stderr.decode().splitlines() + if " = " in line ] ) metrics = {k: float(v) for k, v in metrics.items()} @@ -97,12 +98,17 @@ def runit(params: dict): - - total_energy - mae - PerAtom: True + - - total_energy + - mae + - stratify: 10%_range + PerAtom: True """ ) ) expect_metrics = { "e_mae", "e/N_mae", + "10%-20%_range_e/N_mae", } else: # Write out a fancier metrics file @@ -121,6 +127,16 @@ def runit(params: dict): - - total_energy - mae - PerAtom: True + - - total_energy + - mae + - stratify: 10%_range + PerAtom: True + - - forces + - rmse + - stratify: 10%_population + - - total_energy + - mae + - stratify: 0.5 """ ) ) @@ -131,6 +147,9 @@ def runit(params: dict): "psavg_f_mae", "e_mae", "e/N_mae", + "10%-20%_range_e/N_mae", + "30%-40%_population_f_rmse", + "0.5-1.0_range_e_mae", }.union( { # For the PerSpecies @@ -161,22 +180,27 @@ def runit(params: dict): orig_atoms = ase.io.read(tmpdir + "/out-orig.xyz", index=":", format="extxyz") # check that we have the metrics - assert set(metrics.keys()) == expect_metrics + assert expect_metrics.issubset(set(metrics.keys())) # check metrics if builder == IdentityModel: true_identity: bool = true_config["default_dtype"] == true_config["model_dtype"] for metric, err in metrics.items(): - # see test_train.py for discussion - assert np.allclose( - err, - 0.0, - atol=( - 1e-8 - if true_identity - else (1e-2 if metric.startswith("e") else 1e-4) - ), - ), f"Metric `{metric}` wasn't zero!" + if not np.isnan(err): + # see test_train.py for discussion + assert np.allclose( + err, + 0.0, + atol=( + 1e-8 + if true_identity + else ( + 2e-2 + if any(i in metric for i in ["_e", "e_", "e/N"]) + else 1e-4 + ) + ), + ), f"Metric `{metric}` wasn't zero!" elif builder == ConstFactorModel: # TODO: check comperable to naive numpy compute pass @@ -205,13 +229,17 @@ def runit(params: dict): } ) for k, v in metrics.items(): - assert np.allclose( - v, - metrics2[k], - atol={torch.float32: 1e-6, torch.float64: 1e-8}[ - torch.get_default_dtype() - ], - ) + if not np.isnan(v): + assert np.allclose( + v, + metrics2[k], + atol={ + torch.float32: 1e-6, + torch.float64: 1e-8, + }[torch.get_default_dtype()], + ) + else: + assert np.isnan(metrics2[k]) # assert both are nans # Check the output XYZ batch_atoms = ase.io.read(tmpdir + "/out-orig.xyz", index=":", format="extxyz") diff --git a/tests/unit/trainer/test_loss.py b/tests/unit/trainer/test_loss.py index 101d28fb..c2885cc1 100644 --- a/tests/unit/trainer/test_loss.py +++ b/tests/unit/trainer/test_loss.py @@ -178,6 +178,7 @@ def data(float_tolerance): AtomicDataDict.TOTAL_ENERGY_KEY: torch.rand((2, 1)), "k": torch.rand((2, 1)), AtomicDataDict.ATOM_TYPE_KEY: torch.as_tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), + AtomicDataDict.STRESS_KEY: torch.rand((2, 1)), } ref = { AtomicDataDict.BATCH_KEY: torch.tensor( @@ -187,6 +188,7 @@ def data(float_tolerance): AtomicDataDict.TOTAL_ENERGY_KEY: torch.rand((2, 1)), "k": torch.rand((2, 1)), AtomicDataDict.ATOM_TYPE_KEY: torch.as_tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), + AtomicDataDict.STRESS_KEY: torch.rand((2, 1)), } yield pred, ref diff --git a/tests/unit/trainer/test_metrics.py b/tests/unit/trainer/test_metrics.py index bd29caf0..4bbd9fb0 100644 --- a/tests/unit/trainer/test_metrics.py +++ b/tests/unit/trainer/test_metrics.py @@ -27,6 +27,11 @@ ), (AtomicDataDict.FORCE_KEY, "mae", {"dim": 3}), ), + ( # test stratify settings + (AtomicDataDict.TOTAL_ENERGY_KEY, "mae", {"stratify": "10%_range"}), + (AtomicDataDict.FORCE_KEY, "rmse", {"stratify": "10%_population"}), + (AtomicDataDict.STRESS_KEY, "mae", {"stratify": 0.001}), + ), ]