Skip to content

Commit

Permalink
added function to specify structure ids in dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Sep 13, 2023
1 parent 1d1c0af commit eb2c3f0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
15 changes: 11 additions & 4 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
forces: list[Sequence[Sequence[float]]],
stresses: list[Sequence[Sequence[float]]] | None = None,
magmoms: list[Sequence[Sequence[float]]] | None = None,
structure_ids: list[str] | None = None,
graph_converter: CrystalGraphConverter | None = None,
) -> None:
"""Initialize the dataset.
Expand All @@ -43,8 +44,9 @@ def __init__(
forces (list[list[float]]): [data_size, n_atoms, 3]
stresses (list[list[float]], optional): [data_size, 3, 3]
magmoms (list[list[float]], optional): [data_size, n_atoms, 1]
graph_converter (CrystalGraphConverter, optional): Converts the structures to
graphs. If None, it will be set to CHGNet default converter.
structure_ids (list[str], optional): a list of ids to track the structures
graph_converter (CrystalGraphConverter, optional): Converts the structures
to graphs. If None, it will be set to CHGNet default converter.
Raises:
RuntimeError: if the length of structures and labels (energies, forces,
Expand All @@ -53,7 +55,7 @@ def __init__(
for idx, struct in enumerate(structures):
if not isinstance(struct, Structure):
raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}")
for name in "energies forces stresses magmoms".split():
for name in "energies forces stresses magmoms structure_ids".split():
labels = locals()[name]
if labels is not None and len(labels) != len(structures):
raise RuntimeError(
Expand All @@ -65,6 +67,7 @@ def __init__(
self.forces = forces
self.stresses = stresses
self.magmoms = magmoms
self.structure_ids = structure_ids
self.keys = np.arange(len(structures))
random.shuffle(self.keys)
print(f"{len(structures)} structures imported")
Expand Down Expand Up @@ -93,8 +96,12 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:
graph_id = self.keys[idx]
try:
struct = self.structures[graph_id]
if self.structure_ids is not None:
mp_id = self.structure_ids[graph_id]
else:
mp_id = graph_id
crystal_graph = self.graph_converter(
struct, graph_id=graph_id, mp_id=graph_id
struct, graph_id=graph_id, mp_id=mp_id
)
targets = {
"e": torch.tensor(self.energies[graph_id], dtype=datatype),
Expand Down
12 changes: 11 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
@pytest.fixture()
def structure_data() -> StructureData:
"""Create a graph with 3 nodes and 3 directed edges."""
structures, energies, forces, stresses, magmoms = [], [], [], [], []
structures, energies, forces, stresses, magmoms, structure_ids = (
[],
[],
[],
[],
[],
[],
)
for _ in range(100):
struct = NaCl.copy()
struct.perturb(0.1)
Expand All @@ -26,18 +33,21 @@ def structure_data() -> StructureData:
forces.append(np.random.random([2, 3]))
stresses.append(np.random.random([3, 3]))
magmoms.append(np.random.random([2, 1]))
structure_ids.append("tmp_id")
return StructureData(
structures=structures,
energies=energies,
forces=forces,
stresses=stresses,
magmoms=magmoms,
structure_ids=structure_ids,
)


def test_structure_data(structure_data: StructureData) -> None:
get_one = structure_data[0]
assert isinstance(get_one[0], CrystalGraph)
assert get_one[0].mp_id == "tmp_id"
assert isinstance(get_one[1], dict)
assert isinstance(get_one[1]["e"], torch.Tensor)
assert isinstance(get_one[1]["f"], torch.Tensor)
Expand Down

0 comments on commit eb2c3f0

Please sign in to comment.