Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/ACEsuit/mace into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 1, 2024
2 parents 787cda9 + 36bb33b commit 0764dab
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mace/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
NonLinearReadoutBlock,
RadialEmbeddingBlock,
RealAgnosticAttResidualInteractionBlock,
RealAgnosticDensityInteractionBlock,
RealAgnosticDensityResidualInteractionBlock,
RealAgnosticInteractionBlock,
RealAgnosticResidualInteractionBlock,
ResidualElementDependentInteractionBlock,
Expand Down Expand Up @@ -56,6 +58,8 @@
"RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
"RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
"RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
"RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock,
"RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock,
}

scaling_classes: Dict[str, Callable] = {
Expand Down
175 changes: 175 additions & 0 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,181 @@ def forward(
) # [n_nodes, channels, (lmax + 1)**2]


@compile_mode("script")
class RealAgnosticDensityInteractionBlock(InteractionBlock):
def _setup(self) -> None:
# First linear
self.linear_up = o3.Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = o3.TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
)

# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu,
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)

# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.irreps_out, self.node_attrs_irreps, self.irreps_out
)
self.reshape = reshape_irreps(self.irreps_out)

# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)
# Reshape
self.reshape = reshape_irreps(self.irreps_out)

def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
) -> Tuple[torch.Tensor, None]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.linear(message) / (density + 1)
message = self.skip_tp(message, node_attrs)
return (
self.reshape(message),
None,
) # [n_nodes, channels, (lmax + 1)**2]


@compile_mode("script")
class RealAgnosticDensityResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
# First linear
self.linear_up = o3.Linear(
self.node_feats_irreps,
self.node_feats_irreps,
internal_weights=True,
shared_weights=True,
)
# TensorProduct
irreps_mid, instructions = tp_out_irreps_with_instructions(
self.node_feats_irreps,
self.edge_attrs_irreps,
self.target_irreps,
)
self.conv_tp = o3.TensorProduct(
self.node_feats_irreps,
self.edge_attrs_irreps,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False,
)

# Convolution weights
input_dim = self.edge_feats_irreps.num_irreps
self.conv_tp_weights = nn.FullyConnectedNet(
[input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
torch.nn.functional.silu, # gate
)

# Linear
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps
self.linear = o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
)

# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
self.reshape = reshape_irreps(self.irreps_out)

# Density normalization
self.density_fn = nn.FullyConnectedNet(
[input_dim]
+ [
1,
],
torch.nn.functional.silu,
)

# Reshape
self.reshape = reshape_irreps(self.irreps_out)

def forward(
self,
node_attrs: torch.Tensor,
node_feats: torch.Tensor,
edge_attrs: torch.Tensor,
edge_feats: torch.Tensor,
edge_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
sender = edge_index[0]
receiver = edge_index[1]
num_nodes = node_feats.shape[0]
sc = self.skip_tp(node_feats, node_attrs)
node_feats = self.linear_up(node_feats)
tp_weights = self.conv_tp_weights(edge_feats)
edge_density = torch.tanh(self.density_fn(edge_feats) ** 2)
mji = self.conv_tp(
node_feats[sender], edge_attrs, tp_weights
) # [n_edges, irreps]
density = scatter_sum(
src=edge_density, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, 1]
message = scatter_sum(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]
message = self.linear(message) / (density + 1)
return (
self.reshape(message),
sc,
) # [n_nodes, channels, (lmax + 1)**2]


@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
def _setup(self) -> None:
Expand Down
4 changes: 4 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"RealAgnosticResidualInteractionBlock",
"RealAgnosticAttResidualInteractionBlock",
"RealAgnosticInteractionBlock",
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
],
)
parser.add_argument(
Expand All @@ -163,6 +165,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
choices=[
"RealAgnosticResidualInteractionBlock",
"RealAgnosticInteractionBlock",
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
],
)
parser.add_argument(
Expand Down

0 comments on commit 0764dab

Please sign in to comment.