revamp model_dtype system

cw-tan committed Nov 24, 2024
1 parent 039714d commit ece09b5
Showing 6 changed files with 31 additions and 35 deletions.
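In short: the old system threaded a zero-dim "dtype indicator" tensor (`MODEL_DTYPE_KEY`, backed by the `_model_dtype_example` buffer) through every input data dict so that modules could recover `model_dtype` at runtime via `data.get(...)`. This commit deletes that plumbing: each module now captures `torch.get_default_dtype()` once at construction time into `self._output_dtype`, relying on modules being built inside a `torch_default_dtype(model_dtype)` context, as the updated tests show. A minimal sketch of the new pattern -- an illustrative toy module, not one of the nequip classes below:

    import torch

    class ToyEmbedding(torch.nn.Module):
        # Illustrative only: mirrors the capture-at-init pattern from this commit.
        def __init__(self) -> None:
            super().__init__()
            # Constructed under `torch_default_dtype(model_dtype)`, so the
            # ambient default dtype *is* the model dtype -- record it once.
            self._output_dtype = torch.get_default_dtype()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Compute in the input's (possibly wider) dtype, cast on the way out.
            return torch.sinc(x).to(self._output_dtype)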
3 changes: 1 addition & 2 deletions nequip/data/AtomicDataDict.py

@@ -179,8 +179,7 @@ def frame_from_batched(batched_data: Type, index: int) -> Type:
         elif k in _key_registry._EDGE_FIELDS:  # excluding edge indices
             out[k] = v[torch.eq(torch.index_select(batches, 0, edge_center_idx), index)]
         else:
-            if k != _keys.MODEL_DTYPE_KEY:
-                raise KeyError(f"Unregistered key {k}")
+            raise KeyError(f"Unregistered key {k}")
 
     return out
2 changes: 0 additions & 2 deletions nequip/data/_keys.py

@@ -75,7 +75,5 @@
 BATCH_KEY: Final[str] = "batch"
 NUM_NODES_KEY: Final[str] = "num_atoms"
 
-MODEL_DTYPE_KEY: Final[str] = "_model_dtype_indicator"
-
 # Make a list of allowed keys
 ALLOWED_KEYS: List[str] = [v for k, v in globals().items() if k.endswith("_KEY")]
10 changes: 3 additions & 7 deletions nequip/nn/_graph_model.py

@@ -21,9 +21,9 @@ class GraphModel(GraphModuleMixin, torch.nn.Module):
         model_input_fields (Dict[str, Any]): input fields and their irreps
     """
 
+    type_names: List[str]
     model_dtype: torch.dtype
     model_input_fields: List[str]
-    type_names: List[str]
 
     def __init__(
         self,
@@ -52,14 +52,12 @@ def __init__(
                     f"Model has `{k}` in its irreps_in with irreps `{irreps}`, but `{k}` is missing from/has inconsistent irreps in model_input_fields of `{self.irreps_in.get(k, 'missing')}`"
                 )
         self.model = model
+        # type names and model_dtype aren't actually used -- they're here for recording purposes
+        self.type_names = type_names
         self.model_dtype = (
             model_dtype if model_dtype is not None else torch.get_default_dtype()
         )
         self.model_input_fields = list(self.irreps_in.keys())
-        self.register_buffer(
-            "_model_dtype_example", torch.as_tensor(0.0, dtype=model_dtype)
-        )
-        self.type_names = type_names
 
     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         # restrict the input data to allowed keys to prevent the model from directly using the dict from the outside,
@@ -68,8 +66,6 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         for k in self.model_input_fields:
             if k in data:
                 new_data[k] = data[k]
-        # Store the model dtype indicator tensor in all input data dicts
-        new_data[AtomicDataDict.MODEL_DTYPE_KEY] = self._model_dtype_example
         return self.model(new_data)
 
     @torch.jit.unused
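`GraphModel` thus keeps `model_dtype` purely as recorded metadata, falling back to the ambient default dtype when none is given, and no indicator tensor ever enters the data dict. A hedged construction sketch -- the positional order `(model, type_names, model_dtype)` follows the `_wrap` helper in the test diff below, while the `torch_default_dtype` import path and the `build_core_modules` helper are assumptions:

    import torch
    from nequip.nn import GraphModel
    from nequip.utils import torch_default_dtype  # assumed import path

    model_dtype = torch.float32
    with torch_default_dtype(model_dtype):
        core = build_core_modules()  # hypothetical GraphModuleMixin stack
        model = GraphModel(core, ["H", "C", "O"], model_dtype)

    # Recorded for bookkeeping only; submodules captured the dtype at init.
    assert model.model_dtype == torch.float32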
25 changes: 11 additions & 14 deletions nequip/nn/embedding/_edge.py

@@ -134,6 +134,7 @@ def __init__(
             start=1.0,
             end=self.num_bessels,
             steps=self.num_bessels,
+            dtype=_GLOBAL_DTYPE,
         ).unsqueeze(
             0
         )  # (1, num_bessel)
@@ -149,24 +150,22 @@ def __init__(
                 AtomicDataDict.EDGE_CUTOFF_KEY: "0e",
             },
         )
+        # i.e. `model_dtype`
+        self._output_dtype = torch.get_default_dtype()
 
     def extra_repr(self) -> str:
         return f"num_bessels={self.num_bessels}"
 
     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
-        model_dtype = data.get(
-            AtomicDataDict.MODEL_DTYPE_KEY, data[AtomicDataDict.POSITIONS_KEY]
-        ).dtype
-
         # == Bessel basis ==
         x = data[self.norm_length_field]  # (num_edges, 1)
         # (num_edges, 1), (1, num_bessel) -> (num_edges, num_bessel)
         bessel = (torch.sinc(x * self.bessel_weights) * self.bessel_weights).to(
-            model_dtype
+            self._output_dtype
         )
 
         # == polynomial cutoff ==
-        cutoff = self.cutoff(x).to(model_dtype)
+        cutoff = self.cutoff(x).to(self._output_dtype)
         data[AtomicDataDict.EDGE_CUTOFF_KEY] = cutoff
 
         # == save product ==
@@ -211,15 +210,14 @@ def __init__(
         self.sh = o3.SphericalHarmonics(
             self.irreps_edge_sh, edge_sh_normalize, edge_sh_normalization
         )
+        # i.e. `model_dtype`
+        self._output_dtype = torch.get_default_dtype()
 
     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
-        model_dtype = data.get(
-            AtomicDataDict.MODEL_DTYPE_KEY, data[AtomicDataDict.POSITIONS_KEY]
-        ).dtype
         data = with_edge_vectors_(data, with_lengths=False)
         edge_vec = data[AtomicDataDict.EDGE_VECTORS_KEY]
         edge_sh = self.sh(edge_vec)
-        data[self.out_field] = edge_sh.to(model_dtype)
+        data[self.out_field] = edge_sh.to(self._output_dtype)
         return data
 
 
@@ -237,13 +235,12 @@ def __init__(
         self._init_irreps(
             irreps_in=irreps_in, irreps_out={AtomicDataDict.EDGE_CUTOFF_KEY: "0e"}
         )
+        # i.e. `model_dtype`
+        self._output_dtype = torch.get_default_dtype()
 
     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         if AtomicDataDict.EDGE_CUTOFF_KEY not in data:
-            model_dtype = data.get(
-                AtomicDataDict.MODEL_DTYPE_KEY, data[AtomicDataDict.POSITIONS_KEY]
-            ).dtype
             x = data[self.norm_length_field]
-            cutoff = self.cutoff(x).to(model_dtype)
+            cutoff = self.cutoff(x).to(self._output_dtype)
             data[AtomicDataDict.EDGE_CUTOFF_KEY] = cutoff
         return data
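The `_edge.py` changes all follow one scheme: geometry-derived quantities (Bessel basis, cutoffs, spherical harmonics) stay in wide precision internally -- note the new `dtype=_GLOBAL_DTYPE` on the Bessel weights -- and each module casts its outputs to `self._output_dtype` at the boundary. A standalone sketch of that compute-wide/cast-narrow idea (not nequip code; it assumes `_GLOBAL_DTYPE` is `torch.float64`):

    import torch

    _GLOBAL_DTYPE = torch.float64  # assumption: the internal "global" dtype

    def bessel_features(x: torch.Tensor, num_bessels: int, output_dtype: torch.dtype) -> torch.Tensor:
        # Weights live in float64 so sinc() is evaluated at full precision;
        # shape (1, num_bessel) broadcasts against x of shape (num_edges, 1).
        weights = torch.linspace(1.0, num_bessels, num_bessels, dtype=_GLOBAL_DTYPE).unsqueeze(0)
        feats = torch.sinc(x.to(_GLOBAL_DTYPE) * weights) * weights
        # Only the result is cast down to the model dtype.
        return feats.to(output_dtype)

    x = torch.rand(10, 1, dtype=torch.float64)  # normalized edge lengths
    assert bessel_features(x, 8, torch.float32).dtype == torch.float32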
5 changes: 2 additions & 3 deletions nequip/nn/embedding/_one_hot.py

@@ -37,16 +37,15 @@ def __init__(
             AtomicDataDict.NODE_ATTRS_KEY
         ]
         self._init_irreps(irreps_in=irreps_in, irreps_out=irreps_out)
+        self._output_dtype = torch.get_default_dtype()
 
     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         type_numbers = data[AtomicDataDict.ATOM_TYPE_KEY].view(-1)
         one_hot = torch.nn.functional.one_hot(
             type_numbers, num_classes=self.num_types
         ).to(
             device=type_numbers.device,
-            dtype=data.get(
-                AtomicDataDict.MODEL_DTYPE_KEY, data[AtomicDataDict.POSITIONS_KEY]
-            ).dtype,
+            dtype=self._output_dtype,
         )
         data[AtomicDataDict.NODE_ATTRS_KEY] = one_hot
         if self.set_features:
21 changes: 14 additions & 7 deletions tests/unit/nn/test_embed.py

@@ -1,7 +1,7 @@
 from e3nn.util.test import assert_auto_jitable
 
 from nequip.utils.test import assert_AtomicData_equivariant
-from nequip.nn import SequentialGraphNetwork
+from nequip.nn import SequentialGraphNetwork, GraphModel
 from nequip.nn.embedding import (
     PolynomialCutoff,
     OneHotAtomEncoding,
@@ -14,26 +14,29 @@
 
 def test_onehot(model_dtype, CH3CHO):
     _, data = CH3CHO
-    with torch_default_dtype(dtype_from_name(model_dtype)):
+    mdtype = dtype_from_name(model_dtype)
+    with torch_default_dtype(mdtype):
         oh = OneHotAtomEncoding(
             type_names=["A", "B", "C"],
         )
     assert_auto_jitable(oh)
-    assert_AtomicData_equivariant(oh, data)
+    assert_AtomicData_equivariant(_wrap(oh, mdtype), data)
 
 
 def test_spharm(model_dtype, CH3CHO):
     _, data = CH3CHO
-    with torch_default_dtype(dtype_from_name(model_dtype)):
+    mdtype = dtype_from_name(model_dtype)
+    with torch_default_dtype(mdtype):
         sph = SphericalHarmonicEdgeAttrs(irreps_edge_sh="0e + 1o + 2e")
     assert_auto_jitable(sph)
-    assert_AtomicData_equivariant(sph, data)
+    assert_AtomicData_equivariant(_wrap(sph, mdtype), data)
 
 
 def test_radial_basis(model_dtype, CH3CHO):
     _, data = CH3CHO
 
-    with torch_default_dtype(dtype_from_name(model_dtype)):
+    mdtype = dtype_from_name(model_dtype)
+    with torch_default_dtype(mdtype):
         rad = SequentialGraphNetwork(
             {
                 "edge_norm": EdgeLengthNormalizer(r_max=5.0, type_names=[0, 1, 2]),
@@ -42,4 +45,8 @@ def test_radial_basis(model_dtype, CH3CHO):
         )
     assert_auto_jitable(rad.edge_norm)
     assert_auto_jitable(rad.bessel)
-    assert_AtomicData_equivariant(rad, data)
+    assert_AtomicData_equivariant(_wrap(rad, mdtype), data)
+
+
+def _wrap(module, dtype):
+    return GraphModel(module, ["A", "B", "C"], dtype)
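A note on the new `_wrap` helper: since modules now bake `_output_dtype` in at construction instead of reading an indicator tensor from the data dict, the equivariance checks wrap each module in a `GraphModel` carrying the same dtype (the `["A", "B", "C"]` type names simply match `test_onehot`), presumably so the test harness can see the intended `model_dtype`. A sketch of the resulting invariant, using only names already in this test file:

    mdtype = dtype_from_name("float32")
    with torch_default_dtype(mdtype):
        oh = OneHotAtomEncoding(type_names=["A", "B", "C"])
    wrapped = _wrap(oh, mdtype)  # == GraphModel(oh, ["A", "B", "C"], mdtype)
    assert wrapped.model_dtype == mdtype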
