Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stratified metrics #439

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
bcce773
Update .readthedocs.yaml
kavanase Apr 12, 2024
1602704
Bump `flake8` to avoid linting failure
kavanase Apr 12, 2024
a21178b
Fix typo and reformat code to satisfy now-caught `flake8` linting
kavanase Apr 12, 2024
a742ed5
Merge branch 'mir-group:develop' into develop
kavanase Apr 19, 2024
6589893
Update requirements.txt
kavanase Apr 19, 2024
d61ea38
Initial attempt at stratified metrics
kavanase Jun 26, 2024
87f08d4
Fix evaluate metrics summary being printed twice
kavanase Jun 26, 2024
9608a4d
Add initial range-stratified functionality
kavanase Jun 26, 2024
069fd02
Add population and raw unit stratification
kavanase Jun 26, 2024
6851fd7
Update changelog and add examples of stratified metrics to `full.yaml…
kavanase Jun 26, 2024
9f3bd7c
Add tests for stratified metrics (and manually tested on HPCs)
kavanase Jun 26, 2024
e09311b
Tidy up
kavanase Jun 26, 2024
0170439
Adjust identity error tolerance for stratified energy tests (can have…
kavanase Jun 26, 2024
0b5511f
Merge branch 'refs/heads/main' into stratified_metrics
kavanase Jun 27, 2024
761e980
Update `lint.yaml` to latest versions
kavanase Jun 27, 2024
026250d
Merge branch 'refs/heads/develop' into stratified_metrics
kavanase Jul 3, 2024
472ef3e
Merge branch 'refs/heads/main' into develop
kavanase Jul 3, 2024
2ae5d7f
Merge branch 'refs/heads/develop' into stratified_metrics
kavanase Jul 3, 2024
1059248
Merge branch 'mir-group:main' into stratified_metrics
kavanase Jul 10, 2024
29ef5a4
Fix hard-set parameter value in `plot_dimers.py`
kavanase Jul 5, 2024
2f8f852
`plot_dimers.py` script cleanup
kavanase Jul 5, 2024
33beed1
Add `CITATION.cff` file
kavanase Jul 12, 2024
a457e58
Update README
kavanase Jul 12, 2024
44f01c2
Fill out citation docs page
kavanase Jul 12, 2024
9e915d9
Bibtext syntax highlighting
kavanase Jul 12, 2024
54f5a8c
Update citation file
kavanase Jul 12, 2024
10f551a
Update citation file pt 2
kavanase Jul 12, 2024
ce8b8bd
Update citation file pt 3
kavanase Jul 12, 2024
c1a9d76
Update CITATION.cff
kavanase Jul 12, 2024
d1afaf2
Allow user to set `PYTORCH_VERSION_WARNING=0` to avoid many `pytorch`…
kavanase Jul 19, 2024
ea0f8fe
Merge branch 'refs/heads/develop' into stratified_metrics
kavanase Jul 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Most recent change on the bottom.
- alternate neighborlist support enabled with `NEQUIP_NL` environment variable, which can be set to `ase` (default), `matscipy` or `vesin`
- Allow `n_train` and `n_val` to be specified as percentages of datasets.
- Only attempt training restart if `trainer.pth` file present (prevents unnecessary crashes due to file-not-found errors in some cases)
- Stratified metrics now possible; stratified by reference values in percent or raw units, or by error population.

### Changed
- [Breaking] `NEQUIP_MATSCIPY_NL` environment variable no longer supported
Expand Down
55 changes: 38 additions & 17 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,50 @@ cff-version: "1.2.0"
message: "If you use this software, please cite our article."
authors:
- family-names: Batzner
given-names: Simon
given-names: Simon
- family-names: Musaelian
given-names: Albert
given-names: Albert
- family-names: Sun
given-names: Lixin
given-names: Lixin
- family-names: Geiger
given-names: Mario
given-names: Mario
- family-names: Mailoa
given-names: Jonathan P.
given-names: Jonathan P.
- family-names: Kornbluth
given-names: Mordechai
given-names: Mordechai
- family-names: Molinari
given-names: Nicola
given-names: Nicola
- family-names: Smidt
given-names: Tess E.
given-names: Tess E.
- family-names: Kozinsky
given-names: Boris
given-names: Boris
doi: 10.1038/s41467-022-29939-5
date-published: 2022-05-04
issn: 2041-1723
journal: Nature Communications
start: 2453
title: "E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials"
type: article
url: "https://www.nature.com/articles/s41467-022-29939-5"
volume: 13
preferred-citation:
authors:
- family-names: Batzner
given-names: Simon
- family-names: Musaelian
given-names: Albert
- family-names: Sun
given-names: Lixin
- family-names: Geiger
given-names: Mario
- family-names: Mailoa
given-names: Jonathan P.
- family-names: Kornbluth
given-names: Mordechai
- family-names: Molinari
given-names: Nicola
- family-names: Smidt
given-names: Tess E.
- family-names: Kozinsky
given-names: Boris
doi: 10.1038/s41467-022-29939-5
date-published: 2022-05-04
issn: 2041-1723
journal: Nature Communications
start: 2453
title: "E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials"
type: article
url: "https://www.nature.com/articles/s41467-022-29939-5"
volume: 13
13 changes: 12 additions & 1 deletion configs/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,23 @@ metrics_components:
- - forces
- rmse
- PerSpecies: True
report_per_component: False
report_per_component: False
- - total_energy
- mae
- - total_energy
- mae
- PerAtom: True # if true, energy is normalized by the number of atoms
# we can also output errors stratified by the reference value ranges (in percent or absolute values), or by the error populations in percent:
- - total_energy
- mae
- stratify: 10%_range # stratify by range (in reference energies per atom), in increments of 10% (i.e. errors for first 10% lowest reference values, next 10% etc)
PerAtom: True
- - forces
- rmse
- stratify: 10%_population # stratify by population (in forces errors per atom), in increments of 10% (i.e. errors for first 10% lowest errors, next 10% etc)
- - stress
- mae
- stratify: 0.001 # stratify by absolute value (in reference stresses), in increments of 0.001

# optimizer, may be any optimizer defined in torch.optim
# the name `optimizer_name`is case sensitive
Expand Down
6 changes: 5 additions & 1 deletion nequip/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys

from ._version import __version__ # noqa: F401
Expand All @@ -16,7 +17,10 @@
), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found"

# warn if using 1.13* or 2.0.*
if packaging.version.parse("1.13.0") <= torch_version:
if (
packaging.version.parse("1.13.0") <= torch_version
and int(os.environ.get("PYTORCH_VERSION_WARNING", 1)) != 0
):
warnings.warn(
f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.* have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue."
)
Expand Down
8 changes: 5 additions & 3 deletions nequip/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,13 @@ def main(args=None, running_as_script: bool = True):

if do_metrics:
logger.info("\n--- Final result: ---")
logger.critical(
logger.info(
"\n".join(
f"{k:>20s} = {v:< 20f}"
f"{k:>30s} = {v:< 30f}"
for k, v in metrics.flatten_metrics(
metrics.current_result(),
metrics.current_result(
verbose=True
), # verbose output about strata on final call
type_names=dataset.type_mapper.type_names,
)[0].items()
)
Expand Down
Loading
Loading