diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 2e0df955..bfa4ac09 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -818,6 +818,9 @@ def from_graphs( # Bonds atom_cart_coords = graph.atom_frac_coord @ lattice + if graph.atom_graph.dim() == 1: + # This is to avoid structure with all atoms isolated + graph.atom_graph = graph.atom_graph.reshape(0, 2) bond_basis_ag, bond_basis_bg, bond_vectors = bond_basis_expansion( center=atom_cart_coords[graph.atom_graph[:, 0]], neighbor=atom_cart_coords[graph.atom_graph[:, 1]], diff --git a/tests/test_model.py b/tests/test_model.py index e43f2498..0ae1ca8d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from pymatgen.core import Structure +from pymatgen.core import Lattice, Structure from chgnet import ROOT from chgnet.graph import CrystalGraphConverter @@ -207,6 +207,18 @@ def test_predict_batched_structures() -> None: ) +def test_predict_isolated_structures() -> None: + lattice10 = Lattice.cubic(10) + lattice20 = Lattice.cubic(20) + positions = [[0, 0, 0], [0.5, 0.5, 0.5]] + + # Create the structure + model.graph_converter.set_isolated_atom_response("ignore") + prediction10 = model.predict_structure(Structure(lattice10, ["H", "H"], positions)) + prediction20 = model.predict_structure(Structure(lattice20, ["H", "H"], positions)) + assert prediction10["e"] == pytest.approx(prediction20["e"], rel=1e-5, abs=1e-5) + + def test_as_to_from_dict() -> None: dct = model.as_dict() assert {*dct} == {"model_args", "state_dict"}