From eb2c3f03ed5c64f8a3fd1c09bf852b87b1e95fb3 Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Tue, 12 Sep 2023 18:37:51 -0700 Subject: [PATCH] added function to specify structure ids in dataset --- chgnet/data/dataset.py | 15 +++++++++++---- tests/test_dataset.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py index 4cdeb153..95023bba 100644 --- a/chgnet/data/dataset.py +++ b/chgnet/data/dataset.py @@ -33,6 +33,7 @@ def __init__( forces: list[Sequence[Sequence[float]]], stresses: list[Sequence[Sequence[float]]] | None = None, magmoms: list[Sequence[Sequence[float]]] | None = None, + structure_ids: list[str] | None = None, graph_converter: CrystalGraphConverter | None = None, ) -> None: """Initialize the dataset. @@ -43,8 +44,9 @@ def __init__( forces (list[list[float]]): [data_size, n_atoms, 3] stresses (list[list[float]], optional): [data_size, 3, 3] magmoms (list[list[float]], optional): [data_size, n_atoms, 1] - graph_converter (CrystalGraphConverter, optional): Converts the structures to - graphs. If None, it will be set to CHGNet default converter. + structure_ids (list[str], optional): a list of ids to track the structures + graph_converter (CrystalGraphConverter, optional): Converts the structures + to graphs. If None, it will be set to CHGNet default converter. Raises: RuntimeError: if the length of structures and labels (energies, forces, @@ -53,7 +55,7 @@ def __init__( for idx, struct in enumerate(structures): if not isinstance(struct, Structure): raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}") - for name in "energies forces stresses magmoms".split(): + for name in "energies forces stresses magmoms structure_ids".split(): labels = locals()[name] if labels is not None and len(labels) != len(structures): raise RuntimeError( @@ -65,6 +67,7 @@ def __init__( self.forces = forces self.stresses = stresses self.magmoms = magmoms + self.structure_ids = structure_ids self.keys = np.arange(len(structures)) random.shuffle(self.keys) print(f"{len(structures)} structures imported") @@ -93,8 +96,12 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]: graph_id = self.keys[idx] try: struct = self.structures[graph_id] + if self.structure_ids is not None: + mp_id = self.structure_ids[graph_id] + else: + mp_id = graph_id crystal_graph = self.graph_converter( - struct, graph_id=graph_id, mp_id=graph_id + struct, graph_id=graph_id, mp_id=mp_id ) targets = { "e": torch.tensor(self.energies[graph_id], dtype=datatype), diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 889161a2..503ac7e4 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -17,7 +17,14 @@ @pytest.fixture() def structure_data() -> StructureData: """Create a graph with 3 nodes and 3 directed edges.""" - structures, energies, forces, stresses, magmoms = [], [], [], [], [] + structures, energies, forces, stresses, magmoms, structure_ids = ( + [], + [], + [], + [], + [], + [], + ) for _ in range(100): struct = NaCl.copy() struct.perturb(0.1) @@ -26,18 +33,21 @@ def structure_data() -> StructureData: forces.append(np.random.random([2, 3])) stresses.append(np.random.random([3, 3])) magmoms.append(np.random.random([2, 1])) + structure_ids.append("tmp_id") return StructureData( structures=structures, energies=energies, forces=forces, stresses=stresses, magmoms=magmoms, + structure_ids=structure_ids, ) def test_structure_data(structure_data: StructureData) -> None: get_one = structure_data[0] assert isinstance(get_one[0], CrystalGraph) + assert get_one[0].mp_id == "tmp_id" assert isinstance(get_one[1], dict) assert isinstance(get_one[1]["e"], torch.Tensor) assert isinstance(get_one[1]["f"], torch.Tensor)