Merge pull request #221 from mir-group/develop
0.5.5
Showing 46 changed files with 1,325 additions and 303 deletions.
@@ -0,0 +1,58 @@
# general
root: results/w-14
run_name: minimal
seed: 123
dataset_seed: 456

# network
model_builders:
  - SimpleIrrepsConfig
  - EnergyModel
  - PerSpeciesRescale
  - StressForceOutput
  - RescaleEnergyEtc

num_basis: 8
r_max: 4.0
l_max: 2
parity: true
num_features: 16

# data set
dataset: ase  # type of data set, can be npz or ase
dataset_url: https://qmml.org/Datasets/w-14.zip  # url to download the npz. optional
dataset_file_name: ./benchmark_data/w-14.xyz  # path to data set file
dataset_key_mapping:
  force: forces
dataset_include_keys:
  - virial
# A mapping of chemical species to type indexes is necessary if the dataset is provided with atomic numbers instead of type indexes.
chemical_symbols:
  - W
# only early frames have stress
dataset_include_frames: !!python/object/apply:builtins.range
  - 0
  - 100
  - 1

global_rescale_scale: dataset_total_energy_std
per_species_rescale_shifts: dataset_per_atom_total_energy_mean
per_species_rescale_scales: dataset_per_atom_total_energy_std

# logging
wandb: false
# verbose: debug

# training
n_train: 90
n_val: 10
batch_size: 1
max_epochs: 10

# loss function
loss_coeffs:
  - virial
  - forces

# optimizer
optimizer_name: Adam
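The `dataset_include_frames` entry above uses the YAML tag `!!python/object/apply:builtins.range`, which deserializes directly into a Python `range` object rather than a list. A small standalone sketch (using PyYAML, independent of NequIP) shows how that tag is parsed; note that it requires a loader permitting arbitrary object construction, so it should only ever be used on trusted config files:

```python
import yaml

snippet = """
dataset_include_frames: !!python/object/apply:builtins.range
  - 0
  - 100
  - 1
"""

# The !!python/object/apply tag calls builtins.range(0, 100, 1) during
# loading; this needs UnsafeLoader (or FullLoader will refuse it).
cfg = yaml.load(snippet, Loader=yaml.UnsafeLoader)
frames = cfg["dataset_include_frames"]
print(frames)  # range(0, 100)
```

The advantage over listing the indices explicitly is that the frame selection stays compact and lazy, no matter how large the range is.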
@@ -0,0 +1,46 @@
# general
root: results/toy-emt
run_name: minimal
seed: 123
dataset_seed: 456

# network
model_builders:
  - EnergyModel
  - PerSpeciesRescale
  - StressForceOutput
  - RescaleEnergyEtc
num_basis: 8
r_max: 4.0
irreps_edge_sh: 0e + 1o
conv_to_output_hidden_irreps_out: 16x0e
feature_irreps_hidden: 16x0o + 16x0e + 16x1o + 16x1e

# data set
dataset: EMTTest  # a built-in synthetic test dataset
dataset_element: Cu
dataset_num_frames: 100
chemical_symbols:
  - Cu

global_rescale_scale: dataset_total_energy_std
per_species_rescale_shifts: dataset_per_atom_total_energy_mean
per_species_rescale_scales: dataset_per_atom_total_energy_std

# logging
wandb: false
# verbose: debug

# training
n_train: 90
n_val: 10
batch_size: 1
max_epochs: 100

# loss function
loss_coeffs:  # different weights to use in a weighted loss function
  forces: 1   # for MD applications, we recommend a force weight of 100 and an energy weight of 1
  stress: 1

# optimizer
optimizer_name: Adam
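The `loss_coeffs` mapping weights each named loss term when they are summed into the single scalar that is optimized. As a rough illustration only (a hedged sketch, not NequIP's actual implementation; the names `weighted_loss` and `per_term_losses` are made up here):

```python
def weighted_loss(per_term_losses, loss_coeffs):
    # Sum each named loss term scaled by its coefficient, mirroring how a
    # `loss_coeffs` mapping like {forces: 1, stress: 1} weights the terms.
    return sum(loss_coeffs[name] * per_term_losses[name] for name in loss_coeffs)

# With the coefficients from the config above and some example raw term values:
total = weighted_loss({"forces": 0.5, "stress": 0.25}, {"forces": 1, "stress": 1})
print(total)  # 0.75
```

Raising one coefficient (e.g. `forces: 100`) makes the optimizer prioritize that term relative to the others, which is why the comment above recommends a high force weight for MD applications.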
@@ -0,0 +1,116 @@
from typing import Dict, List, Callable, Union, Optional
import numpy as np
import logging

import torch

from nequip.data import AtomicData
from nequip.utils.savenload import atomic_write
from nequip.data.transforms import TypeMapper
from nequip.data import AtomicDataset


class ExampleCustomDataset(AtomicDataset):
    """
    See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets.
    If you don't need downloading or pre-processing, just don't define any of the relevant methods/properties.
    """

    def __init__(
        self,
        root: str,
        custom_option1,
        custom_option2="default",
        type_mapper: Optional[TypeMapper] = None,
    ):
        # Initialize the AtomicDataset, which runs .download() (if present) and .process().
        # See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets
        # Download and preprocessing only run if cached dataset files aren't found.
        self.mydata = None  # ensure the attribute exists even when process() is skipped
        super().__init__(root=root, type_mapper=type_mapper)

        # If the processed paths didn't exist, `self.process()` has been called at
        # this point (if it is defined); otherwise, load the data from the cached
        # pre-processed directory:
        if self.mydata is None:
            self.mydata = torch.load(self.processed_paths[0])
        # If you didn't define `process()`, this is where you would unconditionally load your data.

    def len(self) -> int:
        """Return the number of frames in the dataset."""
        return 42  # placeholder; return the real number of frames

    @property
    def raw_file_names(self) -> List[str]:
        """Return a list of filenames for the raw data.

        These must be simple filenames, to be looked for in `self.raw_dir`.
        """
        return ["data.dat"]

    @property
    def raw_dir(self) -> str:
        return "/path/to/dataset-folder/"

    @property
    def processed_file_names(self) -> List[str]:
        """Like `self.raw_file_names`, but for the files generated by `self.process()`.

        Should not be paths, just file names. These will be stored in `self.processed_dir`,
        which is set by NequIP in `AtomicDataset` based on `self.root` and a hash of the
        dataset options provided to `__init__`.
        """
        return ["processed-data.pth"]

    # def download(self):
    #     """Optional method to download raw data before preprocessing if the `raw_paths` do not exist."""
    #     pass

    def process(self):
        # Load things from the raw data, in whatever way is appropriate for your format:
        data = np.load(self.raw_dir + "/" + self.raw_file_names[0])

        # If any pre-processing is necessary, do it and cache the results to
        # `self.processed_paths` as defined above:
        with atomic_write(self.processed_paths[0], binary=True) as f:
            # e.g., anything that takes a file `f` will work
            torch.save(data, f)
        # ^ Use atomic writes to avoid race conditions between different
        # trainings that use the same dataset. Since those separate trainings
        # should all produce the same results, it doesn't matter if they
        # overwrite each other's cached datasets; it only matters that they
        # don't simultaneously try to write the _same_ file, corrupting it.

        logging.info("Cached processed data to disk")

        # Optionally, save the processed data on the Dataset object
        # to avoid a roundtrip from disk in `__init__` (see above):
        self.mydata = data

    def get(self, idx: int) -> AtomicData:
        """Return the data frame with a given index as an `AtomicData` object."""
        build_an_AtomicData_here = None
        return build_an_AtomicData_here

    def statistics(
        self,
        fields: List[Union[str, Callable]],
        modes: List[str],
        stride: int = 1,
        unbiased: bool = True,
        kwargs: Optional[Dict[str, dict]] = {},
    ) -> List[tuple]:
        """Optional method to compute statistics over an entire dataset.

        This must correctly handle `self._indices` for subsets! If not provided,
        options like `avg_num_neighbors: auto`, `per_species_rescale_scales: dataset_*`,
        and others that compute dataset statistics will not work. This only needs to
        support the statistics modes that are necessary for what you need to run
        (i.e. if you do not use `dataset_per_species_*` statistics, you do not need
        to implement them). See `AtomicInMemoryDataset` for full documentation and
        an example implementation.
        """
        raise NotImplementedError
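The race-condition note in `process()` is the motivation for `atomic_write`. A minimal stdlib-only sketch of the same write-then-rename pattern (the helper name `atomic_write_text` is made up here; NequIP's actual `atomic_write` in `nequip.utils.savenload` is more general):

```python
import os
import tempfile

def atomic_write_text(path, text):
    # Write to a temporary file in the *same* directory as the target, then
    # rename it over the target. os.replace is atomic on POSIX, so concurrent
    # readers never observe a partially written file, and two writers racing
    # on the same path cannot corrupt it -- the last rename simply wins.
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(path)))
    try:
        with os.fdopen(fd, "w") as f:
            f.write(text)
        os.replace(tmp, path)
    except BaseException:
        os.unlink(tmp)  # clean up the temp file on failure
        raise

demo_path = os.path.join(tempfile.gettempdir(), "atomic-demo.txt")
atomic_write_text(demo_path, "processed data")
```

The temp file must live in the same directory (not `/tmp`) because `os.replace` cannot atomically move a file across filesystems.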