Skip to content

Commit

Permalink
Support for edge case of structures with all isolated atoms
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Jul 5, 2024
1 parent b1bc8a2 commit b819ef5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
3 changes: 3 additions & 0 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,9 @@ def from_graphs(

# Bonds
atom_cart_coords = graph.atom_frac_coord @ lattice
if graph.atom_graph.dim() == 1:
# This is to avoid structure with all atoms isolated
graph.atom_graph = graph.atom_graph.reshape(0, 2)
bond_basis_ag, bond_basis_bg, bond_vectors = bond_basis_expansion(
center=atom_cart_coords[graph.atom_graph[:, 0]],
neighbor=atom_cart_coords[graph.atom_graph[:, 1]],
Expand Down
14 changes: 13 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import pytest
from pymatgen.core import Structure
from pymatgen.core import Lattice, Structure

from chgnet import ROOT
from chgnet.graph import CrystalGraphConverter
Expand Down Expand Up @@ -207,6 +207,18 @@ def test_predict_batched_structures() -> None:
)


def test_predict_isolated_structures() -> None:
lattice10 = Lattice.cubic(10)
lattice20 = Lattice.cubic(20)
positions = [[0, 0, 0], [0.5, 0.5, 0.5]]

# Create the structure
model.graph_converter.set_isolated_atom_response("ignore")
prediction10 = model.predict_structure(Structure(lattice10, ["H", "H"], positions))
prediction20 = model.predict_structure(Structure(lattice20, ["H", "H"], positions))
assert prediction10["e"] == pytest.approx(prediction20["e"], rel=1e-5, abs=1e-5)


def test_as_to_from_dict() -> None:
dct = model.as_dict()
assert {*dct} == {"model_args", "state_dict"}
Expand Down

0 comments on commit b819ef5

Please sign in to comment.