Skip to content

Commit

Permalink
Merge branch 'refs/heads/develop' into stratified_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Jul 3, 2024
2 parents 026250d + 472ef3e commit 2ae5d7f
Show file tree
Hide file tree
Showing 15 changed files with 290 additions and 124 deletions.
14 changes: 13 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

Most recent change on the bottom.

## Unreleased - 0.6.1
### Added
- add support for equivariance testing of arbitrary Cartesian tensor outputs
- [Breaking] use entry points for `nequip.extension`s (e.g. for field registration)
- alternate neighborlist support enabled with `NEQUIP_NL` environment variable, which can be set to `ase` (default), `matscipy` or `vesin`
- Allow `n_train` and `n_val` to be specified as percentages of datasets.
- Only attempt training restart if `trainer.pth` file present (prevents unnecessary crashes due to file-not-found errors in some cases)
- Stratified metrics now possible; stratified by reference values in percent or raw units, or by error population.

### Changed
- [Breaking] `NEQUIP_MATSCIPY_NL` environment variable no longer supported

### Fixed
- Fixed `flake8` install location in `.pre-commit-config.yaml`


## [0.6.0] - 2024-5-10
### Added
- add Tensorboard as logger option
Expand Down
39 changes: 39 additions & 0 deletions nequip/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,40 @@
"""Top-level ``nequip`` package initialization.

On import this module:
  1. exposes ``__version__``,
  2. enforces a PyTorch version compatibility check (1.11.*, 1.13.*, or later;
     1.12.* is explicitly disallowed),
  3. warns about PyTorch versions with known CUDA performance issues, and
  4. discovers installed ``nequip.extension`` entry points so extension
     packages can register themselves (e.g. via ``register_fields``).
"""

import sys

from ._version import __version__  # noqa: F401

import packaging.version

import torch
import warnings

# === torch version checks ===
torch_version = packaging.version.parse(torch.__version__)

# only allow 1.11.*, 1.13.* or higher (no 1.12.*)
# NOTE: use `>=` so an exact 1.11.0 install passes; a strict `>` would wrongly
# reject it (local-tagged builds such as "1.11.0+cu113" compare greater under
# PEP 440 and previously masked this boundary bug).
assert (torch_version >= packaging.version.parse("1.11.0")) and not (
    packaging.version.parse("1.12.0")
    <= torch_version
    < packaging.version.parse("1.13.0")
), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found"

# warn if using 1.13.* or 2.0.* (known upstream performance issues on CUDA)
if packaging.version.parse("1.13.0") <= torch_version < packaging.version.parse("2.1"):
    warnings.warn(
        f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.0.* have been seen to cause unusual performance degradations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue."
    )


# Load all installed nequip extension packages
# This allows installed extensions to register themselves in
# the nequip infrastructure with calls like `register_fields`

# see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata
if sys.version_info < (3, 10):
    # `entry_points(group=...)` requires importlib.metadata >= 3.6 semantics,
    # available in the stdlib only from Python 3.10; use the PyPI backport below
    from importlib_metadata import entry_points
else:
    from importlib.metadata import entry_points

_DISCOVERED_NEQUIP_EXTENSION = entry_points(group="nequip.extension")
for ep in _DISCOVERED_NEQUIP_EXTENSION:
    # only entry points explicitly named `init_always` are loaded eagerly at
    # import time; other extension hooks are left for lazy loading elsewhere
    if ep.name == "init_always":
        ep.load()
134 changes: 109 additions & 25 deletions nequip/data/AtomicData.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
import os

import numpy as np
import ase.neighborlist
import ase
from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator
from ase.calculators.calculator import all_properties as ase_all_properties
from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress

import torch
import e3nn.o3
from e3nn.io import CartesianTensor

from . import AtomicDataDict
from ._util import _TORCH_INTEGER_DTYPES
Expand All @@ -26,6 +26,7 @@
# A type representing ASE-style periodic boundary condtions, which can be partial (the tuple case)
PBC = Union[bool, Tuple[bool, bool, bool]]

# === Key Registration ===

_DEFAULT_LONG_FIELDS: Set[str] = {
AtomicDataDict.EDGE_INDEX_KEY,
Expand Down Expand Up @@ -61,17 +62,23 @@
AtomicDataDict.CELL_KEY,
AtomicDataDict.BATCH_PTR_KEY,
}
_DEFAULT_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = {
AtomicDataDict.STRESS_KEY: "ij=ji",
AtomicDataDict.VIRIAL_KEY: "ij=ji",
}
_NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS)
_EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS)
_GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS)
_LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS)
_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = dict(_DEFAULT_CARTESIAN_TENSOR_FIELDS)


def register_fields(
node_fields: Sequence[str] = [],
edge_fields: Sequence[str] = [],
graph_fields: Sequence[str] = [],
long_fields: Sequence[str] = [],
cartesian_tensor_fields: Dict[str, str] = {},
) -> None:
r"""Register fields as being per-atom, per-edge, or per-frame.
Expand All @@ -83,18 +90,36 @@ def register_fields(
edge_fields: set = set(edge_fields)
graph_fields: set = set(graph_fields)
long_fields: set = set(long_fields)
allfields = node_fields.union(edge_fields, graph_fields)
assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields)

# error checking: prevents registering fields as contradictory types
# potentially unregistered fields
assert len(node_fields.intersection(edge_fields)) == 0
assert len(node_fields.intersection(graph_fields)) == 0
assert len(edge_fields.intersection(graph_fields)) == 0
# already registered fields
assert len(_NODE_FIELDS.intersection(edge_fields)) == 0
assert len(_NODE_FIELDS.intersection(graph_fields)) == 0
assert len(_EDGE_FIELDS.intersection(node_fields)) == 0
assert len(_EDGE_FIELDS.intersection(graph_fields)) == 0
assert len(_GRAPH_FIELDS.intersection(edge_fields)) == 0
assert len(_GRAPH_FIELDS.intersection(node_fields)) == 0

# check that Cartesian tensor fields to add are rank-2 (higher ranks not supported)
for cart_tensor_key in cartesian_tensor_fields:
cart_tensor_rank = len(
CartesianTensor(cartesian_tensor_fields[cart_tensor_key]).indices
)
if cart_tensor_rank != 2:
raise NotImplementedError(
f"Only rank-2 tensor data processing supported, but got {cart_tensor_key} is rank {cart_tensor_rank}. Consider raising a GitHub issue if higher-rank tensor data processing is desired."
)

# update fields
_NODE_FIELDS.update(node_fields)
_EDGE_FIELDS.update(edge_fields)
_GRAPH_FIELDS.update(graph_fields)
_LONG_FIELDS.update(long_fields)
if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < (
len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS)
):
raise ValueError(
"At least one key was registered as more than one of node, edge, or graph!"
)
_CARTESIAN_TENSOR_FIELDS.update(cartesian_tensor_fields)


def deregister_fields(*fields: Sequence[str]) -> None:
Expand All @@ -109,9 +134,16 @@ def deregister_fields(*fields: Sequence[str]) -> None:
assert f not in _DEFAULT_NODE_FIELDS, "Cannot deregister built-in field"
assert f not in _DEFAULT_EDGE_FIELDS, "Cannot deregister built-in field"
assert f not in _DEFAULT_GRAPH_FIELDS, "Cannot deregister built-in field"
assert f not in _DEFAULT_LONG_FIELDS, "Cannot deregister built-in field"
assert (
f not in _DEFAULT_CARTESIAN_TENSOR_FIELDS
), "Cannot deregister built-in field"

_NODE_FIELDS.discard(f)
_EDGE_FIELDS.discard(f)
_GRAPH_FIELDS.discard(f)
_LONG_FIELDS.discard(f)
_CARTESIAN_TENSOR_FIELDS.pop(f, None)


def _register_field_prefix(prefix: str) -> None:
Expand All @@ -125,6 +157,9 @@ def _register_field_prefix(prefix: str) -> None:
)


# === AtomicData ===


def _process_dict(kwargs, ignore_fields=[]):
"""Convert a dict of data into correct dtypes/shapes according to key"""
# Deal with _some_ dtype issues
Expand Down Expand Up @@ -449,17 +484,40 @@ def from_ase(
cell = kwargs.pop("cell", atoms.get_cell())
pbc = kwargs.pop("pbc", atoms.pbc)

# handle ASE-style 6 element Voigt order stress
for key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY):
if key in add_fields:
if add_fields[key].shape == (3, 3):
# it's already 3x3, do nothing else
pass
elif add_fields[key].shape == (6,):
# it's Voigt order
add_fields[key] = voigt_6_to_full_3x3_stress(add_fields[key])
# IMPORTANT: the following reshape logic only applies to rank-2 Cartesian tensor fields
for key in add_fields:
if key in _CARTESIAN_TENSOR_FIELDS:
# enforce (3, 3) shape for graph fields, e.g. stress, virial
if key in _GRAPH_FIELDS:
# handle ASE-style 6 element Voigt order stress
if key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY):
if add_fields[key].shape == (6,):
add_fields[key] = voigt_6_to_full_3x3_stress(
add_fields[key]
)
if add_fields[key].shape == (3, 3):
# it's already 3x3, do nothing else
pass
elif add_fields[key].shape == (9,):
add_fields[key] = add_fields[key].reshape((3, 3))
else:
raise RuntimeError(
f"bad shape for {key} registered as a Cartesian tensor graph field---please note that only rank-2 Cartesian tensors are currently supported"
)
# enforce (N_atom, 3, 3) shape for node fields, e.g. Born effective charges
elif key in _NODE_FIELDS:
if add_fields[key].shape[1:] == (3, 3):
pass
elif add_fields[key].shape[1:] == (9,):
add_fields[key] = add_fields[key].reshape((-1, 3, 3))
else:
raise RuntimeError(
f"bad shape for {key} registered as a Cartesian tensor node field---please note that only rank-2 Cartesian tensors are currently supported"
)
else:
raise RuntimeError(f"bad shape for {key}")
raise RuntimeError(
f"{key} registered as a Cartesian tensor field was not registered as either a graph or node field"
)

return cls.from_points(
pos=atoms.positions,
Expand Down Expand Up @@ -705,12 +763,21 @@ def without_nodes(self, which_nodes):
assert _ERROR_ON_NO_EDGES in ("true", "false")
_ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true"

_NEQUIP_MATSCIPY_NL: Final[bool] = os.environ.get("NEQUIP_MATSCIPY_NL", "false").lower()
assert _NEQUIP_MATSCIPY_NL in ("true", "false")
_NEQUIP_MATSCIPY_NL = _NEQUIP_MATSCIPY_NL == "true"
# use "ase" as default
# TODO: eventually, choose fastest as default
# NOTE:
# - vesin and matscipy do not support self-interaction
# - vesin does not allow for mixed pbcs
_NEQUIP_NL: Final[str] = os.environ.get("NEQUIP_NL", "ase").lower()

if _NEQUIP_MATSCIPY_NL:
if _NEQUIP_NL == "vesin":
from vesin import NeighborList as vesin_nl
elif _NEQUIP_NL == "matscipy":
import matscipy.neighbours
elif _NEQUIP_NL == "ase":
import ase.neighborlist
else:
raise NotImplementedError(f"Unknown neighborlist NEQUIP_NL = {_NEQUIP_NL}")


def neighbor_list_and_relative_vec(
Expand Down Expand Up @@ -790,7 +857,24 @@ def neighbor_list_and_relative_vec(
# ASE dependent part
temp_cell = ase.geometry.complete_cell(temp_cell)

if _NEQUIP_MATSCIPY_NL:
if _NEQUIP_NL == "vesin":
assert strict_self_interaction and not self_interaction
# use same mixed pbc logic as
# https://github.com/Luthaf/vesin/blob/main/python/vesin/src/vesin/_ase.py
if pbc[0] and pbc[1] and pbc[2]:
periodic = True
elif not pbc[0] and not pbc[1] and not pbc[2]:
periodic = False
else:
raise ValueError(
"different periodic boundary conditions on different axes are not supported by vesin neighborlist, use ASE or matscipy"
)

first_idex, second_idex, shifts = vesin_nl(
cutoff=float(r_max), full_list=True
).compute(points=temp_pos, box=temp_cell, periodic=periodic, quantities="ijS")

elif _NEQUIP_NL == "matscipy":
assert strict_self_interaction and not self_interaction
first_idex, second_idex, shifts = matscipy.neighbours.neighbour_list(
"ijS",
Expand All @@ -799,7 +883,7 @@ def neighbor_list_and_relative_vec(
positions=temp_pos,
cutoff=float(r_max),
)
else:
elif _NEQUIP_NL == "ase":
first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
"ijS",
pbc,
Expand Down
5 changes: 4 additions & 1 deletion nequip/data/AtomicDataDict.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def with_edge_vectors(data: Type, with_lengths: bool = True) -> Type:
# (2) works on a Batch constructed from AtomicData
pos = data[_keys.POSITIONS_KEY]
edge_index = data[_keys.EDGE_INDEX_KEY]
edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
# edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
edge_vec = torch.index_select(pos, 0, edge_index[1]) - torch.index_select(
pos, 0, edge_index[0]
)
if _keys.CELL_KEY in data:
# ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero.
# -1 gives a batch dim no matter what
Expand Down
2 changes: 2 additions & 0 deletions nequip/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
_EDGE_FIELDS,
_GRAPH_FIELDS,
_LONG_FIELDS,
_CARTESIAN_TENSOR_FIELDS,
)
from ._dataset import (
AtomicDataset,
Expand Down Expand Up @@ -39,5 +40,6 @@
_EDGE_FIELDS,
_GRAPH_FIELDS,
_LONG_FIELDS,
_CARTESIAN_TENSOR_FIELDS,
EMTTestDataset,
]
6 changes: 1 addition & 5 deletions nequip/data/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from nequip import data
from nequip.data.transforms import TypeMapper
from nequip.data import AtomicDataset, register_fields
from nequip.data import AtomicDataset
from nequip.utils import instantiate, get_w_prefix


Expand Down Expand Up @@ -71,10 +71,6 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
# Build a TypeMapper from the config
type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config)

# Register fields:
# This might reregister fields, but that's OK:
instantiate(register_fields, all_args=config)

instance, _ = instantiate(
class_name,
prefix=prefix,
Expand Down
2 changes: 1 addition & 1 deletion nequip/data/_dataset/_base_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def statistics(
if field not in selectors:
# this means field is not selected and so not available
raise RuntimeError(
f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such using `nequip.data.register_fields`"
f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such"
)
arr = data_transformed[field]
if field in _NODE_FIELDS:
Expand Down
Loading

0 comments on commit 2ae5d7f

Please sign in to comment.