diff --git a/CHANGELOG.md b/CHANGELOG.md index c3a1f3dc..b4859880 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,12 +7,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] + +## [0.3.3] - 2021-08-11 +### Added +- `to_ase` method in `AtomicData.py` to convert `AtomicData` object to (list of) `ase.Atoms` object(s) +- `SequentialGraphNetwork` now has insertion methods +- `nn.SaveForOutput` +- `nequip-evaluate` command for evaluating (metrics on) trained models +- `AtomicData.from_ase` now catches `energy`/`energies` arrays + ### Changed - Nonlinearities now specified with `e` and `o` instead of `1` and `-1` +- Update interfaces for `torch_geometric` 1.7.1 and `e3nn` 0.3.3 +- `nonlinearity_scalars` now also affects the nonlinearity used in the radial net of `InteractionBlock` +- Cleaned up naming of initializers ### Fixed - Fix specifying nonlinearities when wandb enabled +- `Final` backport for <3.8 compatability - Fixed `nequip-*` commands when using `pip install` +- Default models rescale per-atom energies, and not just total +- Fixed Python <3.8 backward compatability with `atomic_save` ## [0.3.2] - 2021-06-09 ### Added diff --git a/README.md b/README.md index 5633ab2f..59228e60 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ NequIP is an open-source code for building E(3)-equivariant interatomic potentials. +[![Documentation Status](https://readthedocs.org/projects/nequip/badge/?version=latest)](https://nequip.readthedocs.io/en/latest/?badge=latest) ![nequip](./nequip.png) @@ -12,35 +13,18 @@ NequIP is an open-source code for building E(3)-equivariant interatomic potentia NequIP requires: * Python >= 3.6 -* PyTorch = 1.8 +* PyTorch >= 1.8 To install: -* Install [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric), make sure to install this with your correct version of CUDA/CPU and to use PyTorch Geometric version 1.7.0: - -``` -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+${CUDA}.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+${CUDA}.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+${CUDA}.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+${CUDA}.html -pip install torch-geometric==1.7.0 -pip install e3nn==0.2.9 -``` - -where ```${CUDA}``` should be replaced by either ```cpu```, ```cu101```, ```cu102```, or ```cu111``` depending on your PyTorch installation, for details see [here](https://github.com/rusty1s/pytorch_geometric). - -* Install [e3nn](https://github.com/e3nn/e3nn), version 0.2.9: - -``` -pip install e3nn==0.2.9 -``` +* Install [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric), following [their installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) and making sure to install with the correct version of CUDA. (Please note that `torch_geometric>=1.7.1)` is required.) * Install our fork of [`pytorch_ema`](https://github.com/Linux-cpp-lisp/pytorch_ema) for using an Exponential Moving Average on the weights: ```bash -$ pip install git+https://github.com/Linux-cpp-lisp/pytorch_ema@context_manager#egg=torch_ema +$ pip install "git+https://github.com/Linux-cpp-lisp/pytorch_ema@context_manager#egg=torch_ema" ``` -* We use [Weights&Biases](https://wandb.ai) to keep track of experiments. This is not a strict requirement, you can use our package without this, but it may make your life easier. If you want to use it, create an account [here](https://wandb.ai) and install it: +* We use [Weights&Biases](https://wandb.ai) to keep track of experiments. This is not a strict requirement — you can use our package without it — but it may make your life easier. If you want to use it, create an account [here](https://wandb.ai) and install the Python package: ``` pip install wandb @@ -56,13 +40,15 @@ pip install . ### Installation Issues -We recommend running the tests using ```pytest``` on a CPU: +We recommend running the tests using ```pytest```: ``` pip install pytest pytest ./tests/ ``` +While the tests are somewhat compute intensive, we've known them to hang on certain systems that have GPUs. If this happens to you, please report it along with information on your software environment in the [Issues](https://github.com/mir-group/nequip/issues)! + ## Usage **! PLEASE NOTE:** the first few training epochs/calls to a NequIP model can be painfully slow. This is expected behaviour as the [profile-guided optimization of TorchScript models](https://program-transformations.github.io/slides/pytorch_neurips.pdf) takes a number of calls to warm up before optimizing the model. This occurs regardless of whether the entire model is compiled because many core components from e3nn are compiled and optimized through TorchScript. @@ -83,9 +69,33 @@ A number of example configuration files are provided: Training runs can be restarted using `nequip-restart`; training that starts fresh or restarts depending on the existance of the working directory can be launched using `nequip-requeue`. All `nequip-*` commands accept the `--help` option to show their call signatures and options. -### In-depth tutorial +### Evaluating trained models (and their error) + +The `nequip-evaluate` command can be used to evaluate a trained model on a specified dataset, optionally computing error metrics or writing the results to an XYZ file for further processing. -A more in-depth introduction to the internals of NequIP can be found in the [tutorial notebook](https://deepnote.com/project/2412ca93-7ad1-4458-972c-5d5add5a667e). +The simplest command is: +```bash +$ nequip-evaluate --train-dir /path/to/training/session/ +``` +which will evaluate the original training error metrics over any part of the original dataset not used in the training or validation sets. + +For more details on this command, please run `nequip-evaluate --help`. + +### Deploying models + +The `nequip-deploy` command is used to deploy the result of a training session into a model that can be stored and used for inference. +It compiles a NequIP model trained in Python to [TorchScript](https://pytorch.org/docs/stable/jit.html). +The result is an optimized model file that has no dependency on the `nequip` Python library, or even on Python itself: +```bash +nequip-deploy build path/to/training/session/ path/to/deployed.pth +``` +For more details on this command, please run `nequip-deploy --help`. + +### Using models in Python + +Both deployed and undeployed models can be used in Python code; for details, see the end of the [Developer's tutorial](https://deepnote.com/project/2412ca93-7ad1-4458-972c-5d5add5a667e) mentioned again below. + +An ASE calculator is also provided in `nequip.dynamics`. ### LAMMPS Integration @@ -93,12 +103,11 @@ NequIP is integrated with the popular Molecular Dynamics code [LAMMPS](https://w The interface is implemented as `pair_style nequip`. Using it requires two simple steps: -1. Deploy a trained NequIP model. This step compiles a NequIP model trained in Python to [TorchScript](https://pytorch.org/docs/stable/jit.html). -The result is an optimized model file that has no Python dependency and can be used by standalone C++ programs such as LAMMPS: - +1. Deploy a trained NequIP model, as discussed above. ``` nequip-deploy build path/to/training/session/ path/to/deployed.pth ``` +The result is an optimized model file that has no Python dependency and can be used by standalone C++ programs such as LAMMPS. 2. Change the LAMMPS input file to the nequip `pair_style` and point it to the deployed NequIP model: @@ -107,29 +116,35 @@ pair_style nequip pair_coeff * * deployed.pth ``` -For installation instructions, please see the NequIP `pair_style` repo at https://github.com/mir-group/pair_nequip. +For installation instructions, please see the [`pair_nequip` repository](https://github.com/mir-group/pair_nequip). -## References +## Developer's tutorial -The theory behind NequIP is described in our preprint [1]. NequIP's backend builds on e3nn, a general framework for building E(3)-equivariant neural networks [2]. +A more in-depth introduction to the internals of NequIP can be found in the [tutorial notebook](https://deepnote.com/project/2412ca93-7ad1-4458-972c-5d5add5a667e). This notebook discusses theoretical background as well as the Python interfaces that can be used to train and call models. - [1] https://arxiv.org/abs/2101.03164 - [2] https://github.com/e3nn/e3nn +Please note that for most common usecases, including customized models, the `nequip-*` commands should be prefered for training models. + +## References & citing + +The theory behind NequIP is described in our preprint (1). NequIP's backend builds on e3nn, a general framework for building E(3)-equivariant neural networks (2). If you use this repository in your work, please consider citing NequIP (1) and e3nn (3): + + 1. https://arxiv.org/abs/2101.03164 + 2. https://e3nn.org + 3. https://doi.org/10.5281/zenodo.3724963 ## Authors NequIP is being developed by: - - Simon Batzner - - Albert Musaelian - - Lixin Sun - - Anders Johansson - - Mario Geiger - - Tess Smidt - -under the guidance of Boris Kozinsky at Harvard. + - Simon Batzner + - Albert Musaelian + - Lixin Sun + - Anders Johansson + - Mario Geiger + - Tess Smidt +under the guidance of [Boris Kozinsky at Harvard](https://bkoz.seas.harvard.edu/). ## Contact & questions @@ -137,10 +152,3 @@ If you have questions, please don't hesitate to reach out at batzner[at]g[dot]ha If you find a bug or have a proposal for a feature, please post it in the [Issues](https://github.com/mir-group/nequip/issues). If you have a question, topic, or issue that isn't obviously one of those, try our [GitHub Disucssions](https://github.com/mir-group/nequip/discussions). - -## Citing - -If you use this repository in your work, please consider citing NequIP (1) and e3nn (2): - - [1] https://arxiv.org/abs/2101.03164 - [2] https://doi.org/10.5281/zenodo.3724963 diff --git a/configs/full.yaml b/configs/full.yaml index 2fce3bf3..7ae8d632 100644 --- a/configs/full.yaml +++ b/configs/full.yaml @@ -12,6 +12,7 @@ seed: 0 restart: false # set True for a restarted run append: false # set True if a restarted run should append to the previous log file default_dtype: float32 # type of float to use, e.g. float32 and float64 +allow_tf32: True # whether to use TensorFloat32 if it is available # network r_max: 4.0 # cutoff radius in length units @@ -170,11 +171,15 @@ optimizer_weight_decay: 0 # weight initialization # this can be the importable name of any function that can be `model.apply`ed to initialize some weights in the model. NequIP provides a number of useful initializers: +# For more details please see the docstrings of the individual initializers #model_initializers: # - nequip.utils.initialization.uniform_initialize_fcs -# - nequip.utils.initialization.uniform_initialize_tps -# - nequip.utils.initialization.orthogonal_initialize_linears -# - nequip.utils.initialization.uniform_initialize_linears +# - nequip.utils.initialization.uniform_initialize_equivariant_linears +# - nequip.utils.initialization.uniform_initialize_tp_internal_weights +# - nequip.utils.initialization.xavier_initialize_fcs +# - nequip.utils.initialization.(unit_)orthogonal_initialize_equivariant_linears +# - nequip.utils.initialization.(unit_)orthogonal_initialize_fcs +# - nequip.utils.initialization.(unit_)orthogonal_initialize_e3nn_fcs # lr scheduler, currently only supports the two options listed below, if you need more please file an issue # first: on-plateau, reduce lr by factory of lr_scheduler_factor if metrics_key hasn't improved for lr_scheduler_patience epoch diff --git a/configs/minimal.yaml b/configs/minimal.yaml index e8cb2917..a45810f9 100644 --- a/configs/minimal.yaml +++ b/configs/minimal.yaml @@ -11,11 +11,6 @@ conv_to_output_hidden_irreps_out: 16x0e feature_irreps_hidden: 16x0o + 16x0e + 16x1o + 16x1e + 16x2o + 16x2e model_uniform_init: false -model_initializers: - - nequip.utils.initialization.uniform_initialize_fcs - - nequip.utils.initialization.uniform_initialize_tps - - nequip.utils.initialization.orthogonal_initialize_linears - # data dataset: aspirin dataset_file_name: benchmark_data/aspirin_ccsd-train.npz diff --git a/nequip/_version.py b/nequip/_version.py index f021bdb0..93a83158 100644 --- a/nequip/_version.py +++ b/nequip/_version.py @@ -2,4 +2,4 @@ # See Python packaging guide # https://packaging.python.org/guides/single-sourcing-package-version/ -__version__ = "0.3.2" +__version__ = "0.3.3" diff --git a/nequip/data/AtomicData.py b/nequip/data/AtomicData.py index e4c679e8..252e5b72 100644 --- a/nequip/data/AtomicData.py +++ b/nequip/data/AtomicData.py @@ -5,11 +5,12 @@ import warnings from copy import deepcopy -from typing import Union, Tuple, Dict, Optional +from typing import Union, Tuple, Dict, Optional, List from collections.abc import Mapping import numpy as np import ase.neighborlist +import ase from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator import torch @@ -200,9 +201,10 @@ def from_ase(cls, atoms, r_max, **kwargs): Respects ``atoms``'s ``pbc`` and ``cell``. - Automatically recognize force, energy (overridden by free energy tag) - get_atomic_numbers() will be stored as the atomic_numbers attributes + First tries to extract energies and forces from a single-point calculator associated with the ``Atoms`` if one is present and has those fields. + If either is not found, the method will look for ``energy``/``energies`` and ``force``/``forces`` in ``atoms.arrays``. + `get_atomic_numbers()` will be stored as the atomic_numbers attribute. Args: atoms (ase.Atoms): the input. @@ -233,10 +235,19 @@ def from_ase(cls, atoms, r_max, **kwargs): "energy" ) - elif "forces" in atoms.arrays: - add_fields[AtomicDataDict.FORCE_KEY] = atoms.arrays["forces"] - elif "force" in atoms.arrays: - add_fields[AtomicDataDict.FORCE_KEY] = atoms.arrays["force"] + if AtomicDataDict.FORCE_KEY not in add_fields: + # Get it from arrays + for k in ("force", "forces"): + if k in atoms.arrays: + add_fields[AtomicDataDict.FORCE_KEY] = atoms.arrays[k] + break + + if AtomicDataDict.TOTAL_ENERGY_KEY not in add_fields: + # Get it from arrays + for k in ("energy", "energies"): + if k in atoms.arrays: + add_fields[AtomicDataDict.TOTAL_ENERGY_KEY] = atoms.arrays[k] + break add_fields[AtomicDataDict.ATOMIC_NUMBERS_KEY] = atoms.get_atomic_numbers() @@ -249,6 +260,72 @@ def from_ase(cls, atoms, r_max, **kwargs): **add_fields, ) + def to_ase(self) -> Union[List[ase.Atoms], ase.Atoms]: + """Build a (list of) ``ase.Atoms`` object(s) from an ``AtomicData`` object. + + For each unique batch number provided in ``AtomicDataDict.BATCH_KEY``, + an ``ase.Atoms`` object is created. If ``AtomicDataDict.BATCH_KEY`` does not + exist in self, a single ``ase.Atoms`` object is created. + + Returns: + A list of ``ase.Atoms`` objects if ``AtomicDataDict.BATCH_KEY`` is in self + and is not None. Otherwise, a single ``ase.Atoms`` object is returned. + """ + positions = self.pos + if positions.device != torch.device("cpu"): + raise TypeError( + "Explicitly move this `AtomicData` to CPU using `.to()` before calling `to_ase()`." + ) + atomic_nums = self.atomic_numbers + pbc = getattr(self, AtomicDataDict.PBC_KEY, None) + cell = getattr(self, AtomicDataDict.CELL_KEY, None) + batch = getattr(self, AtomicDataDict.BATCH_KEY, None) + energy = getattr(self, AtomicDataDict.TOTAL_ENERGY_KEY, None) + force = getattr(self, AtomicDataDict.FORCE_KEY, None) + do_calc = energy is not None or force is not None + + if cell is not None: + cell = cell.view(-1, 3, 3) + if pbc is not None: + pbc = pbc.view(-1, 3) + + if batch is not None: + n_batches = batch.max() + 1 + cell = cell.expand(n_batches, 3, 3) if cell is not None else None + pbc = pbc.expand(n_batches, 3) if pbc is not None else None + else: + n_batches = 1 + + batch_atoms = [] + for batch_idx in range(n_batches): + if batch is not None: + mask = batch == batch_idx + else: + mask = slice(None) + + mol = ase.Atoms( + numbers=atomic_nums[mask], + positions=positions[mask], + cell=cell[batch_idx] if cell is not None else None, + pbc=pbc[batch_idx] if pbc is not None else None, + ) + + if do_calc: + fields = {} + if energy is not None: + fields["energy"] = energy[batch_idx].cpu().numpy() + if force is not None: + fields["forces"] = force[mask].cpu().numpy() + mol.calc = SinglePointCalculator(mol, **fields) + + batch_atoms.append(mol) + + if batch is not None: + return batch_atoms + else: + assert len(batch_atoms) == 1 + return batch_atoms[0] + def get_edge_vectors(data: Data) -> torch.Tensor: data = AtomicDataDict.with_edge_vectors(AtomicData.to_AtomicDataDict(data)) return data[AtomicDataDict.EDGE_VECTORS_KEY] @@ -263,8 +340,15 @@ def to_AtomicDataDict( keys = data.keys() else: raise ValueError(f"Invalid data `{repr(data)}`") + return { - k: data[k] for k in keys if (k not in exclude_keys and data[k] is not None) + k: data[k] + for k in keys + if ( + k not in exclude_keys + and data[k] is not None + and isinstance(data[k], torch.Tensor) + ) } @classmethod @@ -322,7 +406,7 @@ def without_nodes(self, which_nodes): elif k == AtomicDataDict.CELL_KEY: new_dict[k] = self[k] else: - if len(self[k]) == self.num_nodes: + if isinstance(self[k], torch.Tensor) and len(self[k]) == self.num_nodes: new_dict[k] = self[k][mask] else: new_dict[k] = self[k] diff --git a/nequip/data/AtomicDataDict.py b/nequip/data/AtomicDataDict.py index 196c04ee..598b3ff1 100644 --- a/nequip/data/AtomicDataDict.py +++ b/nequip/data/AtomicDataDict.py @@ -30,10 +30,11 @@ def validate_keys(keys, graph_required=True): raise KeyError("At least pos and edge_index must be supplied") if _keys.EDGE_CELL_SHIFT_KEY in keys and "cell" not in keys: raise ValueError("If `edge_cell_shift` given, `cell` must be given.") - if _keys.ATOMIC_NUMBERS_KEY in keys and _keys.SPECIES_INDEX_KEY in keys: - raise ValueError( - "'atomic_numbers' and 'species_index' cannot be simultaneously provided" - ) + # This is in flux; TODO + # if _keys.ATOMIC_NUMBERS_KEY in keys and _keys.SPECIES_INDEX_KEY in keys: + # raise ValueError( + # "'atomic_numbers' and 'species_index' cannot be simultaneously provided" + # ) _SPECIAL_IRREPS = [None] diff --git a/nequip/data/dataloader.py b/nequip/data/dataloader.py index 728a3c2f..6d1debb7 100644 --- a/nequip/data/dataloader.py +++ b/nequip/data/dataloader.py @@ -12,7 +12,7 @@ def __init__(self, fixed_fields=[], exclude_keys=[]): self._exclude_keys = set(exclude_keys) @classmethod - def for_dataset(cls, dataset, exclude_keys=None): + def for_dataset(cls, dataset, exclude_keys=[]): return cls( fixed_fields=list(getattr(dataset, "fixed_fields", {}).keys()), exclude_keys=exclude_keys, diff --git a/nequip/data/dataset.py b/nequip/data/dataset.py index bdc4f737..a0bcc446 100644 --- a/nequip/data/dataset.py +++ b/nequip/data/dataset.py @@ -6,9 +6,11 @@ """ import numpy as np import logging - +import tempfile from os.path import dirname, basename, abspath -from typing import Tuple, Dict, Any, List, Callable, Union, Optional +from typing import Tuple, Dict, Any, List, Callable, Union, Optional, Sequence + +import ase import torch from torch_geometric.data import Batch, Dataset, download_url, extract_zip @@ -54,6 +56,7 @@ def statistics( class AtomicInMemoryDataset(AtomicDataset): r"""Base class for all datasets that fit in memory. + Please note that, as a ``pytorch_geometric`` dataset, it must be backed by some kind of disk storage. By default, the raw file will be stored at root/raw and the processed torch file will be at root/process. @@ -64,10 +67,10 @@ class AtomicInMemoryDataset(AtomicDataset): Subclasses may implement: - ``download()`` or ``self.url`` or ``ClassName.URL`` - Args: + Args: + root (str, optional): Root directory where the dataset should be saved. Defaults to current working directory. file_name (str, optional): file name of data source. only used in children class url (str, optional): url to download data source - root (str, optional): Root directory where the dataset should be saved. Defaults to current working directory. force_fixed_keys (list, optional): keys to move from AtomicData to fixed_fields dictionary extra_fixed_fields (dict, optional): extra key that are not stored in data but needed for AtomicData initialization include_frames (list, optional): the frames to process with the constructor. @@ -123,20 +126,6 @@ def __init__( f"please delete the processed folder and rerun {self.processed_paths[0]}" ) - @classmethod - def from_data_list(cls, data_list: List[AtomicData], **kwargs): - """Make an ``AtomicInMemoryDataset`` from a list of ``AtomicData`` objects. - - Args: - data_list (List[AtomicData]) - **kwargs: passed through to the constructor - Returns: - The constructed ``AtomicInMemoryDataset``. - """ - obj = cls(**kwargs) - obj.get_data = lambda: (data_list,) - return obj - def len(self): if self.data is None: return 0 @@ -295,8 +284,8 @@ def statistics( unbiased: bool = True, modes: Optional[List[Union[str]]] = None, ) -> List[tuple]: - if self.__indices__ is not None: - selector = torch.as_tensor(self.__indices__)[::stride] + if self._indices is not None: + selector = torch.as_tensor(self._indices)[::stride] else: selector = torch.arange(0, self.len(), stride) @@ -482,18 +471,38 @@ def __init__( ) @classmethod - def from_atoms(cls, atoms: list, **kwargs): + def from_atoms_list(cls, atoms: Sequence[ase.Atoms], **kwargs): """Make an ``ASEDataset`` from a list of ``ase.Atoms`` objects. + If `root` is not provided, a temporary directory will be used. + + Please note that this is a convinience method that does NOT avoid a round-trip to disk; the provided ``atoms`` will be written out to a file. + + Ignores ``kwargs["file_name"]`` if it is provided. + Args: - atoms (List[ase.Atoms]) + atoms **kwargs: passed through to the constructor Returns: The constructed ``ASEDataset``. """ - # TO DO, this funciton fails. It also needs to be unit tested + if "root" not in kwargs: + tmpdir = tempfile.TemporaryDirectory() + kwargs["root"] = tmpdir.name + else: + tmpdir = None + kwargs["file_name"] = tmpdir.name + "/atoms.xyz" + atoms = list(atoms) + # Write them out + ase.io.write(kwargs["file_name"], atoms, format="extxyz") + # Read them in obj = cls(**kwargs) - obj.get_atoms = lambda: atoms + if tmpdir is not None: + # Make it keep a reference to the tmpdir to keep it alive + # When the dataset is garbage collected, the tmpdir will + # be too, and will (hopefully) get deleted eventually. + # Or at least by end of program... + obj._tmpdir_ref = tmpdir return obj @property diff --git a/nequip/datasets/aspirin.py b/nequip/datasets/aspirin.py index d6096f40..01557e0d 100644 --- a/nequip/datasets/aspirin.py +++ b/nequip/datasets/aspirin.py @@ -2,14 +2,11 @@ from os.path import dirname, basename, abspath -from ase import units -from ase.io import read - from nequip.data import AtomicDataDict, AtomicInMemoryDataset class AspirinDataset(AtomicInMemoryDataset): - """Aspirin DFT/CCSD(T) data """ + """Aspirin DFT/CCSD(T) data""" URL = "http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip" FILE_NAME = "benchmark_data/aspirin_ccsd-train.npz" @@ -30,8 +27,7 @@ def get_data(self): AtomicDataDict.TOTAL_ENERGY_KEY: data["E"].reshape([-1, 1]), } fixed_fields = { - AtomicDataDict.ATOMIC_NUMBERS_KEY: np.asarray(data["z"], dtype=np.int), + AtomicDataDict.ATOMIC_NUMBERS_KEY: np.asarray(data["z"], dtype=int), AtomicDataDict.PBC_KEY: np.array([False, False, False]), - AtomicDataDict.CELL_KEY: None, } return arrays, fixed_fields diff --git a/nequip/nn/__init__.py b/nequip/nn/__init__.py index 2a58ea2d..f86102b8 100644 --- a/nequip/nn/__init__.py +++ b/nequip/nn/__init__.py @@ -9,3 +9,4 @@ from ._grad_output import GradientOutput, ForceOutput # noqa: F401 from ._rescale import RescaleOutput # noqa: F401 from ._convnetlayer import ConvNetLayer # noqa: F401 +from ._util import SaveForOutput # noqa: F401 diff --git a/nequip/nn/_graph_mixin.py b/nequip/nn/_graph_mixin.py index 8499ef7f..9498201b 100644 --- a/nequip/nn/_graph_mixin.py +++ b/nequip/nn/_graph_mixin.py @@ -229,6 +229,53 @@ def append_from_parameters( self.append(name, instance) return + def insert(self, after: str, name: str, module: GraphModuleMixin) -> None: + """Insert a module after the module with name ``after``. + + Args: + after: the module to insert after + name: the name of the module to insert + module: the moldule to insert + """ + # This checks names, etc. + self.add_module(name, module) + # Now insert in the right place by overwriting + names = list(self._modules.keys()) + modules = list(self._modules.values()) + idx = names.index(after) + names.insert(idx + 1, name) + modules.insert(idx + 1, module) + self._modules = OrderedDict(zip(names, modules)) + return + + def insert_from_parameters( + self, + after: str, + shared_params: Mapping, + name: str, + builder: Callable, + params: Dict[str, Any] = {}, + ) -> None: + r"""Build a module from parameters and insert it after ``after``. + + Args: + after: the name of the module to insert after + shared_params (dict-like): shared parameters from which to pull when instantiating the module + name (str): the name for the module + builder (callable): a class or function to build a module + params (dict, optional): extra specific parameters for this module that take priority over those in ``shared_params`` + """ + idx = list(self._modules.keys()).index(after) + instance, _ = instantiate( + builder=builder, + prefix=name, + positional_args=(dict(irreps_in=self[idx].irreps_out)), + optional_args=params, + all_args=shared_params, + ) + self.insert(after, name, instance) + return + # Copied from https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential # with type annotations added def forward(self, input: AtomicDataDict.Type) -> AtomicDataDict.Type: diff --git a/nequip/nn/_interaction_block.py b/nequip/nn/_interaction_block.py index f0b01b32..8370cf85 100644 --- a/nequip/nn/_interaction_block.py +++ b/nequip/nn/_interaction_block.py @@ -1,5 +1,5 @@ """ Interaction Block """ -from typing import Optional +from typing import Optional, Dict, Callable import torch @@ -26,6 +26,7 @@ def __init__( invariant_neurons=8, avg_num_neighbors=None, use_sc=False, + nonlinearity_scalars: Dict[int, Callable] = {"e": "ssp"}, ) -> None: """ InteractionBlock. @@ -114,7 +115,10 @@ def __init__( [self.irreps_in[AtomicDataDict.EDGE_EMBEDDING_KEY].num_irreps] + invariant_layers * [invariant_neurons] + [tp.weight_numel], - ShiftedSoftPlus, + { + "ssp": ShiftedSoftPlus, + "silu": torch.nn.functional.silu, + }[nonlinearity_scalars["e"]], ) self.tp = tp diff --git a/nequip/nn/_rescale.py b/nequip/nn/_rescale.py index 10ccc7ff..a946eb80 100644 --- a/nequip/nn/_rescale.py +++ b/nequip/nn/_rescale.py @@ -29,6 +29,9 @@ class RescaleOutput(GraphModuleMixin, torch.nn.Module): scale_keys: List[str] shift_keys: List[str] + trainable_global_rescale_scale: bool + trainable_global_rescale_shift: bool + _has_scale: bool _has_shift: bool @@ -104,6 +107,19 @@ def __init__( # register dummy for TorchScript self.register_buffer("shift_by", torch.Tensor()) + # Finally, we tell all the modules in the model that there is rescaling + # This allows them to update parameters, like physical constants with units, + # that need to be scaled + # + # Note that .modules() walks the full tree, including self + for mod in self.model.modules(): + if isinstance(mod, GraphModuleMixin): + callback = getattr(mod, "update_for_rescale", None) + if callable(callback): + # It gets the `RescaleOutput` as an argument, + # since that contains all relevant information + callback(self) + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: data = self.model(data) if self.training: diff --git a/nequip/nn/_util.py b/nequip/nn/_util.py new file mode 100644 index 00000000..95c3f969 --- /dev/null +++ b/nequip/nn/_util.py @@ -0,0 +1,29 @@ +import torch + +from nequip.data import AtomicDataDict +from nequip.nn import GraphModuleMixin + + +class SaveForOutput(torch.nn.Module, GraphModuleMixin): + """Copy a field and disconnect it from the autograd graph. + + Copy a field and disconnect it from the autograd graph, storing it under another key for inspection as part of the models output. + + Args: + field: the field to save + out_field: the key to put the saved copy in + """ + + field: str + out_field: str + + def __init__(self, field: str, out_field: str, irreps_in=None): + super().__init__() + self._init_irreps(irreps_in=irreps_in) + self.irreps_out[out_field] = self.irreps_in[field] + self.field = field + self.out_field = out_field + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: + data[self.out_field] = data[self.field].detach().clone() + return data diff --git a/nequip/scripts/deploy.py b/nequip/scripts/deploy.py index 32ec7582..6c7a4e66 100644 --- a/nequip/scripts/deploy.py +++ b/nequip/scripts/deploy.py @@ -1,4 +1,10 @@ -from typing import Final, Tuple, Dict, Union +import sys + +if sys.version_info[1] >= 8: + from typing import Final +else: + from typing_extensions import Final +from typing import Tuple, Dict, Union import argparse import pathlib import logging diff --git a/nequip/scripts/evaluate.py b/nequip/scripts/evaluate.py new file mode 100644 index 00000000..420bcc9e --- /dev/null +++ b/nequip/scripts/evaluate.py @@ -0,0 +1,298 @@ +import sys +import argparse +import textwrap +from pathlib import Path +import contextlib +from tqdm.auto import tqdm + +import ase.io + +import torch + +from nequip.utils import Config, dataset_from_config +from nequip.data import AtomicData, Collater +from nequip.scripts.deploy import load_deployed_model +from nequip.utils import load_file, instantiate +from nequip.train.loss import Loss +from nequip.train.metrics import Metrics + + +def main(args=None): + # in results dir, do: nequip-deploy build . deployed.pth + parser = argparse.ArgumentParser( + description=textwrap.dedent( + """Compute the error of a model on a test set using various metrics. + + The model, metrics, dataset, etc. can specified individually, or a training session can be indicated with `--train-dir`. + + Prints only the final result in `name = num` format to stdout; all other information is printed to stderr. + + WARNING: Please note that results of CUDA models are rarely exactly reproducible, and that even CPU models can be nondeterministic. + """ + ) + ) + parser.add_argument( + "--train-dir", + help="Path to a working directory from a training session.", + type=Path, + default=None, + ) + parser.add_argument( + "--model", + help="A deployed or pickled NequIP model to load. If omitted, defaults to `best_model.pth` in `train_dir`.", + type=Path, + default=None, + ) + parser.add_argument( + "--dataset-config", + help="A YAML config file specifying the dataset to load test data from. If omitted, `config_final.yaml` in `train_dir` will be used", + type=Path, + default=None, + ) + parser.add_argument( + "--metrics-config", + help="A YAML config file specifying the metrics to compute. If omitted, `config_final.yaml` in `train_dir` will be used. If the config does not specify `metrics_components`, the default is to print MAEs and RMSEs for all fields given in the loss function. If the literal string `None`, no metrics will be computed.", + type=str, + default=None, + ) + parser.add_argument( + "--test-indexes", + help="Path to a file containing the indexes in the dataset that make up the test set. If omitted, all data frames *not* used as training or validation data in the training session `train_dir` will be used.", + type=Path, + default=None, + ) + parser.add_argument( + "--batch-size", + help="Batch size to use. Larger is usually faster on GPU.", + type=int, + default=50, + ) + parser.add_argument( + "--device", + help="Device to run the model on. If not provided, defaults to CUDA if available and CPU otherwise.", + type=str, + default=None, + ) + parser.add_argument( + "--output", + help="XYZ file to write out the test set and model predicted forces, energies, etc. to.", + type=Path, + default=None, + ) + # Something has to be provided + # See https://stackoverflow.com/questions/22368458/how-to-make-argparse-print-usage-when-no-option-is-given-to-the-code + if len(sys.argv) == 1: + parser.print_help() + parser.exit() + # Parse the args + args = parser.parse_args(args=args) + + # Do the defaults: + dataset_is_from_training: bool = False + if args.train_dir: + if args.dataset_config is None: + args.dataset_config = args.train_dir / "config_final.yaml" + dataset_is_from_training = True + if args.metrics_config is None: + args.metrics_config = args.train_dir / "config_final.yaml" + if args.model is None: + args.model = args.train_dir / "best_model.pth" + if args.test_indexes is None: + # Find the remaining indexes that arent train or val + trainer = torch.load( + str(args.train_dir / "trainer.pth"), map_location="cpu" + ) + train_idcs = set(trainer["train_idcs"].tolist()) + val_idcs = set(trainer["val_idcs"].tolist()) + else: + train_idcs = val_idcs = None + # update + if args.metrics_config == "None": + args.metrics_config = None + elif args.metrics_config is not None: + args.metrics_config = Path(args.metrics_config) + do_metrics = args.metrics_config is not None + # validate + if args.dataset_config is None: + raise ValueError("--dataset-config or --train-dir must be provided") + if args.metrics_config is None and args.output is None: + raise ValueError( + "Nothing to do! Must provide at least one of --metrics-config, --train-dir (to use training config for metrics), or --output" + ) + if args.model is None: + raise ValueError("--model or --train-dir must be provided") + if args.output is not None: + if args.output.suffix != ".xyz": + raise ValueError("Only extxyz format for `--output` is supported.") + + if args.device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(args.device) + print(f"Using device: {device}", file=sys.stderr) + if device.type == "cuda": + print( + "WARNING: please note that models running on CUDA are usually nondeterministc and that this manifests in the final test errors; for a _more_ deterministic result, please use `--device cpu`", + file=sys.stderr, + ) + + # Load model: + print("Loading model... ", file=sys.stderr, end="") + try: + model, _ = load_deployed_model(args.model, device=device) + print("loaded deployed model.", file=sys.stderr) + except ValueError: # its not a deployed model + model = torch.load(args.model, map_location=device) + model = model.to(device) + print("loaded pickled Python model.", file=sys.stderr) + + # Load a config file + print( + f"Loading {'original training ' if dataset_is_from_training else ''}dataset...", + file=sys.stderr, + ) + config = Config.from_file(str(args.dataset_config)) + + # Currently, pytorch_geometric prints some status messages to stdout while loading the dataset + # TODO: fix may come soon: https://github.com/rusty1s/pytorch_geometric/pull/2950 + # Until it does, just redirect them. + with contextlib.redirect_stdout(sys.stderr): + dataset = dataset_from_config(config) + + c = Collater.for_dataset(dataset, exclude_keys=[]) + + # Determine the test set + # this makes no sense if a dataset is given seperately + if train_idcs is not None and dataset_is_from_training: + # we know the train and val, get the rest + all_idcs = set(range(len(dataset))) + # set operations + test_idcs = list(all_idcs - train_idcs - val_idcs) + assert set(test_idcs).isdisjoint(train_idcs) + assert set(test_idcs).isdisjoint(val_idcs) + print( + f"Using training dataset minus training and validation frames, yielding a test set size of {len(test_idcs)} frames.", + file=sys.stderr, + ) + if not do_metrics: + print( + "WARNING: using the automatic test set ^^^ but not computing metrics, is this really what you wanted to do?", + file=sys.stderr, + ) + else: + # load from file + test_idcs = load_file( + supported_formats=dict( + torch=["pt", "pth"], yaml=["yaml", "yml"], json=["json"] + ), + filename=str(args.test_indexes), + ) + print( + f"Using provided test set indexes, yielding a test set size of {len(test_idcs)} frames.", + file=sys.stderr, + ) + + # Figure out what metrics we're actually computing + if do_metrics: + metrics_config = Config.from_file(str(args.metrics_config)) + metrics_components = metrics_config.get("metrics_components", None) + # See trainer.py: init() and init_metrics() + # Default to loss functions if no metrics specified: + if metrics_components is None: + loss, _ = instantiate( + builder=Loss, + prefix="loss", + positional_args=dict(coeffs=metrics_config.loss_coeffs), + all_args=metrics_config, + ) + metrics_components = [] + for key, func in loss.funcs.items(): + params = { + "PerSpecies": type(func).__name__.startswith("PerSpecies"), + } + metrics_components.append((key, "mae", params)) + metrics_components.append((key, "rmse", params)) + + metrics, _ = instantiate( + builder=Metrics, + prefix="metrics", + positional_args=dict(components=metrics_components), + all_args=metrics_config, + ) + metrics.to(device=device) + + batch_i: int = 0 + batch_size: int = args.batch_size + + print("Starting...", file=sys.stderr) + context_stack = contextlib.ExitStack() + with contextlib.ExitStack() as context_stack: + # "None" checks if in a TTY and disables if not + prog = context_stack.enter_context(tqdm(total=len(test_idcs), disable=None)) + if do_metrics: + display_bar = context_stack.enter_context( + tqdm( + bar_format="" + if prog.disable # prog.ncols doesn't exist if disabled + else ("{desc:." + str(prog.ncols) + "}"), + disable=None, + ) + ) + + if args.output is not None: + output = context_stack.enter_context(open(args.output, "w")) + else: + output = None + + while True: + datas = [ + dataset.get(int(idex)) + for idex in test_idcs[batch_i * batch_size : (batch_i + 1) * batch_size] + ] + if len(datas) == 0: + break + batch = c.collate(datas) + batch = batch.to(device) + out = model(AtomicData.to_AtomicDataDict(batch)) + + with torch.no_grad(): + # Write output + if output is not None: + ase.io.write( + output, + AtomicData.from_AtomicDataDict(out).to(device="cpu").to_ase(), + format="extxyz", + append=True, + ) + # Accumulate metrics + if do_metrics: + metrics(out, batch) + display_bar.set_description_str( + " | ".join( + f"{k} = {v:4.2f}" + for k, v in metrics.flatten_metrics( + metrics.current_result() + )[0].items() + ) + ) + + batch_i += 1 + prog.update(batch.num_graphs) + + prog.close() + if do_metrics: + display_bar.close() + + if do_metrics: + print(file=sys.stderr) + print("--- Final result: ---", file=sys.stderr) + print( + "\n".join( + f"{k:>20s} = {v:< 20f}" + for k, v in metrics.flatten_metrics(metrics.current_result())[0].items() + ) + ) + + +if __name__ == "__main__": + main() diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py index e487d991..455185c7 100644 --- a/nequip/scripts/train.py +++ b/nequip/scripts/train.py @@ -28,6 +28,7 @@ model_initializers=[], dataset_statistics_stride=1, default_dtype="float32", + allow_tf32=True, verbose="INFO", model_debug_mode=False, equivariance_test=False, @@ -79,6 +80,14 @@ def _load_callable(obj: Union[str, Callable]) -> Callable: def fresh_start(config): # = Set global state = + # Set TF32 support + # See https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if torch.cuda.is_available(): + if torch.torch.backends.cuda.matmul.allow_tf32 and not config.allow_tf32: + # it is enabled, and we dont want it to, so disable: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + if config.model_debug_mode: set_irreps_debug(enabled=True) torch.set_default_dtype( @@ -217,6 +226,11 @@ def fresh_start(config): [AtomicDataDict.FORCE_KEY] if AtomicDataDict.FORCE_KEY in core_model.irreps_out else [] + ) + + ( + [AtomicDataDict.PER_ATOM_ENERGY_KEY] + if AtomicDataDict.PER_ATOM_ENERGY_KEY in core_model.irreps_out + else [] ), scale_by=global_scale, shift_keys=AtomicDataDict.TOTAL_ENERGY_KEY, diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py index 049efcdd..cf1e2b0d 100644 --- a/nequip/train/trainer.py +++ b/nequip/train/trainer.py @@ -13,9 +13,10 @@ import yaml from copy import deepcopy from os.path import isfile -from time import perf_counter +from time import perf_counter, gmtime, strftime from typing import Optional, Union + if sys.version_info[1] >= 7: import contextlib else: @@ -23,9 +24,12 @@ import contextlib2 as contextlib import numpy as np +import e3nn +import torch_geometric import torch from torch_ema import ExponentialMovingAverage +import nequip from nequip.data import DataLoader, AtomicData, AtomicDataDict from nequip.utils import ( Output, @@ -312,13 +316,15 @@ def __init__( # initialize the optimizer and scheduler, the params will be updated in the function self.init() - self.statistics = {} - if not (restart and append): + d = self.as_dict() for key in list(d.keys()): if not isinstance(d[key], (float, int, str, list, tuple)): - d[key] = type(d[key]) + d[key] = repr(d[key]) + + d["start_time"] = strftime("%a, %d %b %Y %H:%M:%S", gmtime()) + self.log_dictionary(d, name="Initialization") logging.debug("! Done Initialize Trainer") @@ -395,6 +401,9 @@ def as_dict(self, state_dict: bool = False, training_progress: bool = False): dictionary["progress"]["last_model_path"] = self.last_model_path dictionary["progress"]["trainer_save_path"] = self.trainer_save_path + for code in [e3nn, nequip, torch, torch_geometric]: + dictionary[f"{code.__name__}_version"] = code.__version__ + return dictionary def save(self, filename, format=None): @@ -460,6 +469,15 @@ def from_dict(cls, dictionary, append: Optional[bool] = None): d = deepcopy(dictionary) + for code in [e3nn, nequip, torch, torch_geometric]: + version = d.get(f"{code.__name__}_version", None) + if version is not None and version != code.__version__: + logging.warning( + "Loading a pickled model created with different library version(s) may cause issues." + f"current {code.__name__} verion: {code.__version__} " + f"vs original version: {version}" + ) + # update the restart and append option d["restart"] = True if append is not None: diff --git a/nequip/utils/initialization.py b/nequip/utils/initialization.py index 2e8c0e02..70ef1fe9 100644 --- a/nequip/utils/initialization.py +++ b/nequip/utils/initialization.py @@ -13,21 +13,23 @@ def unit_uniform_init_(t: torch.Tensor): def uniform_initialize_fcs(mod: torch.nn.Module): - """Initialize ``e3nn.nn.FullyConnectedNet``s with ``unit_uniform_init_``""" + """Initialize ``e3nn.nn.FullyConnectedNet``s with ``unit_uniform_init_`` + + No need to do torch.nn.Linear, which is uniform by default. + """ if isinstance(mod, e3nn.nn.FullyConnectedNet): - for w in mod.weights: - unit_uniform_init_(w) - # no need to do torch.nn.Linear, which is uniform by default + for layer in mod: + unit_uniform_init_(layer.weight) -def uniform_initialize_linears(mod: torch.nn.Module): - """Initialize ``e3nn.o3.Linear``s with ``unit_uniform_init_``""" +def uniform_initialize_equivariant_linears(mod: torch.nn.Module): + """Initialize ``e3nn.o3.Linear``s that have internal weights with ``unit_uniform_init_``""" if isinstance(mod, e3nn.o3.Linear) and mod.internal_weights: unit_uniform_init_(mod.weight) -def uniform_initialize_tps(mod: torch.nn.Module): - """Initialize ``e3nn.o3.TensorProduct``s with ``unit_uniform_init_``""" +def uniform_initialize_tp_internal_weights(mod: torch.nn.Module): + """Initialize ``e3nn.o3.TensorProduct``s that have internal weights with ``unit_uniform_init_``""" if isinstance(mod, e3nn.o3.TensorProduct) and mod.internal_weights: unit_uniform_init_(mod.weight) @@ -36,40 +38,66 @@ def uniform_initialize_tps(mod: torch.nn.Module): def xavier_initialize_fcs(mod: torch.nn.Module): """Initialize ``e3nn.nn.FullyConnectedNet``s and ``torch.nn.Linear``s with Xavier uniform initialization""" if isinstance(mod, e3nn.nn.FullyConnectedNet): - for w in mod.weights: + for layer in mod: # in FC: # h_in, _h_out = W.shape # W = W / h_in**0.5 - torch.nn.init.xavier_uniform_(w, gain=w.shape[0] ** 0.5) + torch.nn.init.xavier_uniform_( + layer.weight, gain=layer.weight.shape[0] ** 0.5 + ) elif isinstance(mod, torch.nn.Linear): torch.nn.init.xavier_uniform_(mod.weight) # == Orthogonal == +# TODO: does this normalization make any sense def unit_orthogonal_init_(t: torch.Tensor): """Orthogonal init with = 1""" assert t.ndim == 2 torch.nn.init.orthogonal_(t, gain=math.sqrt(max(t.shape))) -def orthogonal_initialize_linears(mod: torch.nn.Module): - """Initialize ``e3nn.o3.Linear``s with ``unit_orthogonal_init_``""" +def unit_orthogonal_initialize_equivariant_linears(mod: torch.nn.Module): + """Initialize ``e3nn.o3.Linear``s that have internal weights with ``unit_orthogonal_init_``""" if isinstance(mod, e3nn.o3.Linear) and mod.internal_weights: for w in mod.weight_views(): - unit_uniform_init_(w) + unit_orthogonal_init_(w) -def orthogonal_initialize_fcs(mod: torch.nn.Module): - """Initialize ``e3nn.nn.FullyConnectedNet``s and ``torch.nn.Linear``s with orthogonal initialization""" +def unit_orthogonal_initialize_fcs(mod: torch.nn.Module): + """Initialize ``e3nn.nn.FullyConnectedNet``s and ``torch.nn.Linear``s with ``unit_orthogonal_init_``""" if isinstance(mod, e3nn.nn.FullyConnectedNet): - for w in mod.weights: - torch.nn.init.orthogonal_(w) + for layer in mod: + unit_orthogonal_init_(layer.weight) elif isinstance(mod, torch.nn.Linear): - torch.nn.init.orthogonal_(mod.weight) + unit_orthogonal_init_(mod.weight) def unit_orthogonal_initialize_e3nn_fcs(mod: torch.nn.Module): """Initialize only ``e3nn.nn.FullyConnectedNet``s with ``unit_orthogonal_init_``""" if isinstance(mod, e3nn.nn.FullyConnectedNet): - for w in mod.weights: - unit_orthogonal_init_(w) + for layer in mod: + unit_orthogonal_init_(layer.weight) + + +def orthogonal_initialize_equivariant_linears(mod: torch.nn.Module): + """Initialize ``e3nn.o3.Linear``s that have internal weights with ``torch.nn.init.orthogonal_``""" + if isinstance(mod, e3nn.o3.Linear) and mod.internal_weights: + for w in mod.weight_views(): + torch.nn.init.orthogonal_(w) + + +def orthogonal_initialize_fcs(mod: torch.nn.Module): + """Initialize ``e3nn.nn.FullyConnectedNet``s and ``torch.nn.Linear``s with ``torch.nn.init.orthogonal_``""" + if isinstance(mod, e3nn.nn.FullyConnectedNet): + for layer in mod: + torch.nn.init.orthogonal_(layer.weight) + elif isinstance(mod, torch.nn.Linear): + torch.nn.init.orthogonal_(mod.weight) + + +def orthogonal_initialize_e3nn_fcs(mod: torch.nn.Module): + """Initialize only ``e3nn.nn.FullyConnectedNet``s with ``torch.nn.init.orthogonal_``""" + if isinstance(mod, e3nn.nn.FullyConnectedNet): + for layer in mod: + torch.nn.init.orthogonal_(layer.weight) diff --git a/nequip/utils/savenload.py b/nequip/utils/savenload.py index 202b6bf5..8c87a853 100644 --- a/nequip/utils/savenload.py +++ b/nequip/utils/savenload.py @@ -2,6 +2,7 @@ utilities that involve file searching and operations (i.e. save/load) """ from typing import Union +import sys import logging import contextlib from pathlib import Path @@ -22,7 +23,13 @@ def atomic_write(filename: Union[Path, str]): tmp_path.rename(filename) finally: # clean up - tmp_path.unlink(missing_ok=True) + # better for python 3.8 > + if sys.version_info[1] >= 8: + tmp_path.unlink(missing_ok=True) + else: + # race condition? + if tmp_path.exists(): + tmp_path.unlink() def save_file( diff --git a/setup.py b/setup.py index 9dea66e7..db624516 100644 --- a/setup.py +++ b/setup.py @@ -22,15 +22,17 @@ "nequip-train = nequip.scripts.train:main", "nequip-restart = nequip.scripts.restart:main", "nequip-requeue = nequip.scripts.requeue:main", + "nequip-evaluate = nequip.scripts.evaluate:main", "nequip-deploy = nequip.scripts.deploy:main", ] }, install_requires=[ "numpy", "ase", + "tqdm", "torch>=1.8", - "torch_geometric", - "e3nn>=0.3", + "torch_geometric>=1.7.1", + "e3nn>=0.3.3", "pyyaml", "contextlib2;python_version<'3.7'", # backport of nullcontext "typing_extensions;python_version<'3.8'", diff --git a/tests/conftest.py b/tests/conftest.py index 8786d675..4a03bcb3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,13 +3,13 @@ import pathlib import pytest import tempfile -import torch from ase.atoms import Atoms from ase.build import molecule from ase.calculators.singlepoint import SinglePointCalculator from ase.io import write +import torch from torch_geometric.data import Batch from nequip.utils.test import set_irreps_debug @@ -25,8 +25,24 @@ # Suppress linter errors float_tolerance = float_tolerance +# - Ampere and TF32 - +# Many of the tests for NequIP involve numerically checking +# algebraic properties— normalization, equivariance, +# continuity, etc. +# With the added numerical noise of TF32, some of those tests fail +# with the current (and usually generous) thresholds. +# +# Thus we go on the assumption that PyTorch + NVIDIA got everything +# right, that this setting DOES NOT AFFECT the model outputs except +# for increased numerical noise, and only test without it. +# +# TODO: consider running tests with and without +# TODO: check how much thresholds have to be changed to accomidate TF32 +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False -@pytest.fixture() + +@pytest.fixture(scope="session") def BENCHMARK_ROOT(): return pathlib.Path(__file__).parent / "../benchmark_data/" diff --git a/tests/data/test_AtomicData.py b/tests/data/test_AtomicData.py index 845a2931..4af9cd27 100644 --- a/tests/data/test_AtomicData.py +++ b/tests/data/test_AtomicData.py @@ -1,7 +1,9 @@ import pytest +import copy import numpy as np import torch +from torch_geometric.data import Batch import ase.build import ase.geometry @@ -19,6 +21,40 @@ def test_from_ase(CuFcc): assert data[key].shape == (len(atoms), 3) # 4 species in this atoms +def test_to_ase(CH3CHO): + atoms, data = CH3CHO + to_ase_atoms = data.to_ase() + assert np.allclose(atoms.get_positions(), to_ase_atoms.get_positions()) + assert np.array_equal(atoms.get_atomic_numbers(), to_ase_atoms.get_atomic_numbers()) + assert np.array_equal(atoms.get_pbc(), to_ase_atoms.get_pbc()) + assert np.array_equal(atoms.get_cell(), to_ase_atoms.get_cell()) + + +def test_to_ase_batches(atomic_batch): + atomic_data = AtomicData.from_dict(atomic_batch.to_dict()) + to_ase_atoms_batch = atomic_data.to_ase() + for batch_idx, atoms in enumerate(to_ase_atoms_batch): + mask = atomic_data.batch == batch_idx + assert atoms.get_positions().shape == (len(atoms), 3) + assert np.allclose(atoms.get_positions(), atomic_data.pos[mask]) + assert atoms.get_atomic_numbers().shape == (len(atoms),) + assert np.array_equal( + atoms.get_atomic_numbers(), atomic_data.atomic_numbers[mask] + ) + assert np.array_equal(atoms.get_cell(), atomic_data.cell[batch_idx]) + assert np.array_equal(atoms.get_pbc(), atomic_data.pbc[batch_idx]) + + +def test_ase_roundtrip(CuFcc): + atoms, data = CuFcc + atoms2 = data.to_ase() + assert np.allclose(atoms.get_positions(), atoms2.get_positions()) + assert np.array_equal(atoms.get_atomic_numbers(), atoms2.get_atomic_numbers()) + assert np.array_equal(atoms.get_pbc(), atoms2.get_pbc()) + assert np.allclose(atoms.get_cell(), atoms2.get_cell()) + assert np.allclose(atoms.calc.results["forces"], atoms2.calc.results["forces"]) + + def test_non_periodic_edge(CH3CHO): atoms, data = CH3CHO # check edges @@ -62,7 +98,7 @@ def test_without_nodes(CH3CHO): assert new_data.edge_index.min() >= 0 assert new_data.edge_index.max() == new_data.num_nodes - 1 - which_nodes_mask = np.zeros(len(atoms), dtype=np.bool) + which_nodes_mask = np.zeros(len(atoms), dtype=bool) which_nodes_mask[[0, 1, 2, 4]] = True new_data = data.without_nodes(which_nodes=which_nodes_mask) assert new_data.num_nodes == len(atoms) - np.sum(which_nodes_mask) @@ -157,6 +193,21 @@ def test_silicon_neighbors(Si): assert edge_index_set_equiv(data.edge_index, edge_index_true) +def test_batching(Si): + _, orig = Si + N = 4 + datas = [] + for _ in range(N): + new = copy.deepcopy(orig) + new.pos += torch.randn_like(new.pos) + datas.append(new) + batch = Batch.from_data_list(datas) + for i, orig in enumerate(datas): + new = batch.get_example(i) + for k, v in orig: + assert torch.equal(v, new[k]) + + def edge_index_set_equiv(a, b): """Compare edge_index arrays in an unordered way.""" # [[0, 1], [1, 0]] -> {(0, 1), (1, 0)} diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 7131d225..871ecf96 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -67,9 +67,9 @@ def npz_dataset(): Z=np.random.randint(1, 108, size=(nframes, natoms)), ) with tempfile.TemporaryDirectory() as folder: - np.savez(folder + "npzdata.npz", **npz) + np.savez(folder + "/npzdata.npz", **npz) a = NpzDataset( - file_name=folder + "npzdata.npz", + file_name=folder + "/npzdata.npz", root=folder, extra_fixed_fields={"r_max": 3}, ) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 3fb63cb5..c75b5d82 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -180,3 +180,15 @@ def test_ase(self, ase_file, root): a = dataset_from_config(config) assert isdir(a.root) assert isdir(f"{a.root}/processed") + + +class TestFromList: + def test_from_atoms(self, molecules): + dataset = ASEDataset.from_atoms_list( + molecules, extra_fixed_fields={"r_max": 4.5} + ) + assert len(dataset) == len(molecules) + for i, mol in enumerate(molecules): + assert np.array_equal( + mol.get_atomic_numbers(), dataset.get(i).to_ase().get_atomic_numbers() + ) diff --git a/tests/model/test_eng_force.py b/tests/model/test_eng_force.py index b1bb8406..a8db666e 100644 --- a/tests/model/test_eng_force.py +++ b/tests/model/test_eng_force.py @@ -13,7 +13,7 @@ from nequip.data import AtomicDataDict, AtomicData from nequip.models import EnergyModel, ForceModel from nequip.nn import GraphModuleMixin, AtomwiseLinear -from nequip.utils.initialization import uniform_initialize_linears +from nequip.utils.initialization import uniform_initialize_equivariant_linears from nequip.utils.test import assert_AtomicData_equivariant @@ -117,7 +117,7 @@ def test_weight_init(self, model, atomic_batch, device): out_orig = instance(data)[out_field] with torch.no_grad(): - instance.apply(uniform_initialize_linears) + instance.apply(uniform_initialize_equivariant_linears) out_unif = instance(data)[out_field] assert not torch.allclose(out_orig, out_unif) diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py index 43f3890a..a081ff34 100644 --- a/tests/nn/test_sequential.py +++ b/tests/nn/test_sequential.py @@ -38,3 +38,30 @@ def test_append(): } ) assert out["thing"].shape == out[AtomicDataDict.NODE_FEATURES_KEY].shape + + +def test_insert(): + sgn = SequentialGraphNetwork.from_parameters( + shared_params={"num_species": 3}, + layers={"one_hot": OneHotAtomEncoding, "lin2": AtomwiseLinear}, + ) + sgn.insert_from_parameters( + after="one_hot", + shared_params={"out_field": "thing"}, + name="lin1", + builder=AtomwiseLinear, + params={"out_field": AtomicDataDict.NODE_FEATURES_KEY}, + ) + assert isinstance(sgn.lin1, AtomwiseLinear) + assert len(sgn) == 3 + assert sgn[0] is sgn.one_hot + assert sgn[1] is sgn.lin1 + assert sgn[2] is sgn.lin2 + out = sgn( + { + AtomicDataDict.POSITIONS_KEY: torch.randn(5, 3), + AtomicDataDict.EDGE_INDEX_KEY: torch.LongTensor([[0, 1], [1, 0]]), + AtomicDataDict.SPECIES_INDEX_KEY: torch.LongTensor([0, 0, 1, 2, 0]), + } + ) + assert AtomicDataDict.NODE_FEATURES_KEY in out diff --git a/tests/nn/test_utils.py b/tests/nn/test_utils.py new file mode 100644 index 00000000..ee645924 --- /dev/null +++ b/tests/nn/test_utils.py @@ -0,0 +1,29 @@ +import torch + +from nequip.data import AtomicDataDict +from nequip.nn.embedding import OneHotAtomEncoding +from nequip.nn import SequentialGraphNetwork, SaveForOutput, AtomwiseLinear + + +def test_basic(): + sgn = SequentialGraphNetwork.from_parameters( + shared_params={"num_species": 4}, + layers={ + "one_hot": OneHotAtomEncoding, + "save": ( + SaveForOutput, + dict(field=AtomicDataDict.NODE_FEATURES_KEY, out_field="saved"), + ), + "linear": AtomwiseLinear, + }, + ) + out = sgn( + { + AtomicDataDict.POSITIONS_KEY: torch.randn(5, 3), + AtomicDataDict.EDGE_INDEX_KEY: torch.LongTensor([[0, 1], [1, 0]]), + AtomicDataDict.SPECIES_INDEX_KEY: torch.LongTensor([0, 0, 1, 2, 0]), + } + ) + saved = out["saved"] + assert saved.shape == (5, 4) + assert torch.all(saved[0] == torch.as_tensor([1.0, 0.0, 0.0, 0.0])) diff --git a/tests/scripts/test_deploy.py b/tests/scripts/test_deploy.py index 9bdc66dc..c0b32124 100644 --- a/tests/scripts/test_deploy.py +++ b/tests/scripts/test_deploy.py @@ -16,9 +16,9 @@ def test_deploy(nequip_dataset, BENCHMARK_ROOT): dtype = str(torch.get_default_dtype())[len("torch.") :] - if torch.cuda.is_available(): - # TODO: is this true? - pytest.skip("CUDA and subprocesses have issues") + # if torch.cuda.is_available(): + # # TODO: is this true? + # pytest.skip("CUDA and subprocesses have issues") config_path = pathlib.Path(__file__).parents[2] / "configs/minimal.yaml" true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader) @@ -50,12 +50,13 @@ def test_deploy(nequip_dataset, BENCHMARK_ROOT): assert deployed_path.is_file(), "Deploy didn't create file" # now test predictions the same - data = AtomicData.to_AtomicDataDict(nequip_dataset.get(0)) + best_mod = torch.load(f"{tmpdir}/{run_name}/best_model.pth") + device = next(best_mod.parameters()).device + data = AtomicData.to_AtomicDataDict(nequip_dataset.get(0).to(device)) # Needed because of debug mode: data[AtomicDataDict.TOTAL_ENERGY_KEY] = data[ AtomicDataDict.TOTAL_ENERGY_KEY ].unsqueeze(0) - best_mod = torch.load(f"{tmpdir}/{run_name}/best_model.pth") train_pred = best_mod(data)[AtomicDataDict.TOTAL_ENERGY_KEY] # load model and check that metadata saved @@ -63,14 +64,17 @@ def test_deploy(nequip_dataset, BENCHMARK_ROOT): deploy.NEQUIP_VERSION_KEY: "", deploy.R_MAX_KEY: "", } - deploy_mod = torch.jit.load(deployed_path, _extra_files=metadata) + deploy_mod = torch.jit.load( + deployed_path, _extra_files=metadata, map_location="cpu" + ) # Everything we store right now is ASCII, so decode for printing metadata = {k: v.decode("ascii") for k, v in metadata.items()} assert metadata[deploy.NEQUIP_VERSION_KEY] == nequip.__version__ assert np.allclose(float(metadata[deploy.R_MAX_KEY]), true_config["r_max"]) + data = AtomicData.to_AtomicDataDict(nequip_dataset.get(0).to("cpu")) deploy_pred = deploy_mod(data)[AtomicDataDict.TOTAL_ENERGY_KEY] - assert torch.allclose(train_pred, deploy_pred, atol=1e-7) + assert torch.allclose(train_pred.to("cpu"), deploy_pred, atol=1e-7) # now test info retcode = subprocess.run( diff --git a/tests/scripts/test_evaluate.py b/tests/scripts/test_evaluate.py new file mode 100644 index 00000000..472de3aa --- /dev/null +++ b/tests/scripts/test_evaluate.py @@ -0,0 +1,188 @@ +import pytest +import tempfile +import pathlib +import yaml +import subprocess +import os +import textwrap +import shutil + +import numpy as np +import torch + +from nequip.data import AtomicDataDict + +from test_train import ConstFactorModel, IdentityModel # noqa + + +@pytest.fixture( + scope="module", + params=[ + ("minimal.yaml", AtomicDataDict.FORCE_KEY), + ], +) +def conffile(request): + return request.param + + +@pytest.fixture(scope="module", params=[ConstFactorModel, IdentityModel]) +def training_session(request, BENCHMARK_ROOT, conffile): + conffile, _ = conffile + builder = request.param + dtype = str(torch.get_default_dtype())[len("torch.") :] + + # if torch.cuda.is_available(): + # # TODO: is this true? + # pytest.skip("CUDA and subprocesses have issues") + + path_to_this_file = pathlib.Path(__file__) + config_path = path_to_this_file.parents[2] / f"configs/{conffile}" + true_config = yaml.load(config_path.read_text(), Loader=yaml.Loader) + with tempfile.TemporaryDirectory() as tmpdir: + # == Run training == + # Save time + run_name = "test_train_" + dtype + true_config["run_name"] = run_name + true_config["root"] = tmpdir + true_config["dataset_file_name"] = str( + BENCHMARK_ROOT / "aspirin_ccsd-train.npz" + ) + true_config["default_dtype"] = dtype + true_config["max_epochs"] = 2 + true_config["model_builder"] = builder + + # to be a true identity, we can't have rescaling + true_config["global_rescale_shift"] = None + true_config["global_rescale_scale"] = None + + config_path = tmpdir + "/conf.yaml" + with open(config_path, "w+") as fp: + yaml.dump(true_config, fp) + # == Train model == + env = dict(os.environ) + # make this script available so model builders can be loaded + env["PYTHONPATH"] = ":".join( + [str(path_to_this_file.parent)] + env.get("PYTHONPATH", "").split(":") + ) + retcode = subprocess.run( + ["nequip-train", str(config_path)], cwd=tmpdir, env=env + ) + retcode.check_returncode() + + yield builder, true_config, tmpdir, env + + +@pytest.mark.parametrize("do_test_idcs", [True, False]) +@pytest.mark.parametrize("do_metrics", [True, False]) +def test_metrics(training_session, do_test_idcs, do_metrics): + builder, true_config, tmpdir, env = training_session + # == Run test error == + outdir = f"{true_config['root']}/{true_config['run_name']}/" + + default_params = {"train-dir": outdir, "output": tmpdir + "/out.xyz"} + + def runit(params: dict): + tmp = default_params.copy() + tmp.update(params) + params = tmp + del tmp + retcode = subprocess.run( + ["nequip-evaluate"] + + sum( + (["--" + k, str(v)] for k, v in params.items() if v is not None), + start=[], + ), + cwd=tmpdir, + env=env, + stdout=subprocess.PIPE, + ) + retcode.check_returncode() + + # Check the output + metrics = dict( + [ + tuple(e.strip() for e in line.split("=", 1)) + for line in retcode.stdout.decode().splitlines() + ] + ) + metrics = {k: float(v) for k, v in metrics.items()} + return metrics + + # Test idcs + if do_test_idcs: + # The Aspirin dataset is 1000 frames long + # Pick some arbitrary number of frames + test_idcs_arr = torch.randperm(1000)[:257] + test_idcs = tmpdir + "/some-test-idcs.pth" + torch.save(test_idcs_arr, test_idcs) + else: + test_idcs = None # ignore and use default + default_params["test-indexes"] = test_idcs + + # Metrics + if do_metrics: + # Write an explicit metrics file + metrics_yaml = tmpdir + "/my-metrics.yaml" + with open(metrics_yaml, "w") as f: + # Write out a fancier metrics file + # We don't use PerSpecies here since the simple models don't fill SPECIES_INDEX right now + # ^ TODO! + f.write( + textwrap.dedent( + """ + metrics_components: + - - forces + - rmse + - report_per_component: True + """ + ) + ) + expect_metrics = {"f_rmse_0", "f_rmse_1", "f_rmse_2"} + else: + metrics_yaml = None + # Regardless of builder, with minimal.yaml, we should have RMSE and MAE + expect_metrics = {"f_mae", "f_rmse"} + default_params["metrics-config"] = metrics_yaml + + # First run + metrics = runit({"train-dir": outdir, "batch-size": 200, "device": "cpu"}) + # move out.xyz to out-orig.xyz + shutil.move(tmpdir + "/out.xyz", tmpdir + "/out-orig.xyz") + + assert set(metrics.keys()) == expect_metrics + + if builder == IdentityModel: + for metric, err in metrics.items(): + assert np.allclose(err, 0.0), f"Metric `{metric}` wasn't zero!" + elif builder == ConstFactorModel: + # TODO: check comperable to naive numpy compute + pass + + # Check insensitive to batch size + for batch_size in (13, 1000): + metrics2 = runit( + {"train-dir": outdir, "batch-size": batch_size, "device": "cpu"} + ) + for k, v in metrics.items(): + assert np.all(np.abs(v - metrics2[k]) < 1e-5) + # Diff the output XYZ, which shouldn't change at all + # Use `cmp`, which is UNIX standard, to make efficient + # See https://stackoverflow.com/questions/12900538/fastest-way-to-tell-if-two-files-have-the-same-contents-in-unix-linux + cmp_retval = subprocess.run( + ["cmp", "--silent", tmpdir + "/out-orig.xyz", tmpdir + "/out.xyz"] + ) + if cmp_retval.returncode == 0: + # same + pass + if cmp_retval.returncode == 1: + raise AssertionError( + f"Changing batch size to {batch_size} changed out.xyz!" + ) + else: + cmp_retval.check_returncode() # error out for subprocess problem + + # Check GPU + if torch.cuda.is_available(): + metrics_gpu = runit({"train-dir": outdir, "batch-size": 17, "device": "cuda"}) + for k, v in metrics.items(): + assert np.all(np.abs(v - metrics_gpu[k]) < 1e-3) # GPU nondeterminism diff --git a/tests/scripts/test_train.py b/tests/scripts/test_train.py index 06179ba6..c6d798b1 100644 --- a/tests/scripts/test_train.py +++ b/tests/scripts/test_train.py @@ -90,9 +90,9 @@ def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, field, builder): dtype = str(torch.get_default_dtype())[len("torch.") :] - if torch.cuda.is_available(): - # TODO: is this true? - pytest.skip("CUDA and subprocesses have issues") + # if torch.cuda.is_available(): + # # TODO: is this true? + # pytest.skip("CUDA and subprocesses have issues") path_to_this_file = pathlib.Path(__file__) config_path = path_to_this_file.parents[2] / f"configs/{conffile}" @@ -187,4 +187,6 @@ def test_metrics(nequip_dataset, BENCHMARK_ROOT, conffile, field, builder): one = model.model.one # Since the loss is always zero, even though the constant # 1 was trainable, it shouldn't have changed - assert torch.allclose(one, torch.ones(1)) + assert torch.allclose( + one, torch.ones(1, device=one.device, dtype=one.dtype) + )