Skip to content

Commit

Permalink
Fix StructureData inconsistent labels length error (#69)
Browse files Browse the repository at this point in the history
* document RuntimeError

* use locals() instead of zip to avoid name-labels mismatch

* add test_structure_data_inconsistent_length()

* tweak test_structure_data_inconsistent_length assert
  • Loading branch information
janosh authored Sep 12, 2023
1 parent 8074497 commit 639fcd8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
18 changes: 10 additions & 8 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,20 @@ def __init__(
magmoms (list[list[float]], optional): [data_size, n_atoms, 1]
graph_converter (CrystalGraphConverter, optional): Converts the structures to
graphs. If None, it will be set to CHGNet default converter.
Raises:
RuntimeError: if the length of structures and labels (energies, forces,
stresses, magmoms) are not equal.
"""
for idx, struct in enumerate(structures):
if not isinstance(struct, Structure):
raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}")
for label, name in zip(
[energies, forces, stresses, magmoms],
["energies, forces,stresses, magmoms"],
):
if len(label) != len(structures):
raise ValueError(
f"Error! inconsistent number of structures and labels: "
f" len(structures)={len(structures)}, len({name})={len(label)})"
for name in "energies forces stresses magmoms".split():
labels = locals()[name]
if labels is not None and len(labels) != len(structures):
raise RuntimeError(
f"Inconsistent number of structures and labels: "
f"{len(structures)=}, len({name})={len(labels)}"
)
self.structures = structures
self.energies = energies
Expand Down
14 changes: 14 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,17 @@ def test_data_loader(structure_data: StructureData) -> None:
assert targets["s"][0].shape == (3, 3)
assert len(targets["m"]) == 16
assert targets["m"][0].shape == (2, 1)


def test_structure_data_inconsistent_length():
# https://github.com/CederGroupHub/chgnet/pull/69
structures = [NaCl.copy() for _ in range(5)]
energies = [np.random.random(1) for _ in range(5)]
forces = [np.random.random([2, 3]) for _ in range(4)]
with pytest.raises(RuntimeError) as exc:
StructureData(structures=structures, energies=energies, forces=forces)

assert (
str(exc.value)
== f"Inconsistent number of structures and labels: {len(structures)=}, {len(forces)=}"
)

0 comments on commit 639fcd8

Please sign in to comment.