diff --git a/requirements.txt b/requirements.txt index df9a237..3c7c495 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,3 @@ tqdm networkx ninja jinja2 -class-resolver diff --git a/torchdrug/layers/__init__.py b/torchdrug/layers/__init__.py index ea18488..22c3b60 100644 --- a/torchdrug/layers/__init__.py +++ b/torchdrug/layers/__init__.py @@ -3,7 +3,7 @@ from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \ NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv from .pool import DiffPool, MinCutPool -from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort, readout_resolver, Readout +from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort from .flow import ConditionalFlow from .sampler import NodeSampler, EdgeSampler from . import distribution, functional @@ -23,7 +23,7 @@ "MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv", "NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv", "DiffPool", "MinCutPool", - "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout", + "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "ConditionalFlow", "NodeSampler", "EdgeSampler", "distribution", "functional", diff --git a/torchdrug/layers/readout.py b/torchdrug/layers/readout.py index 66ed64e..3680d29 100644 --- a/torchdrug/layers/readout.py +++ b/torchdrug/layers/readout.py @@ -1,14 +1,9 @@ import torch from torch import nn from torch_scatter import scatter_mean, scatter_add, scatter_max -from class_resolver import ClassResolver -class Readout(nn.Module): - """A base class for readouts.""" - - -class MeanReadout(Readout): +class MeanReadout(nn.Module): """Mean readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -26,7 +21,7 @@ def forward(self, graph, input): return output -class SumReadout(Readout): +class SumReadout(nn.Module): """Sum readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -44,7 +39,7 @@ def forward(self, graph, input): return output -class MaxReadout(Readout): +class MaxReadout(nn.Module): """Max readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -62,12 +57,6 @@ def forward(self, graph, input): return output -readout_resolver = ClassResolver.from_subclasses( - Readout, - default=SumReadout, -) - - class Softmax(nn.Module): """Softmax operator over graphs with variadic sizes.""" diff --git a/torchdrug/models/chebnet.py b/torchdrug/models/chebnet.py index 521aaef..86d2aef 100644 --- a/torchdrug/models/chebnet.py +++ b/torchdrug/models/chebnet.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.ChebNet") @@ -31,7 +29,7 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(ChebyshevConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gat.py b/torchdrug/models/gat.py index db2b5e7..c22e848 100644 --- a/torchdrug/models/gat.py +++ b/torchdrug/models/gat.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.GAT") @@ -31,7 +29,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, - batch_norm=False, activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): super(GraphAttentionNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega self.layers.append(layers.GraphAttentionConv(self.dims[i], self.dims[i + 1], edge_input_dim, num_head, negative_slope, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gcn.py b/torchdrug/models/gcn.py index 679ae28..c105641 100644 --- a/torchdrug/models/gcn.py +++ b/torchdrug/models/gcn.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.GCN") @@ -29,7 +27,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(GraphConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -44,7 +42,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, for i in range(len(self.dims) - 1): self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py index f32e2ca..f0d99de 100644 --- a/torchdrug/models/gin.py +++ b/torchdrug/models/gin.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.GIN") @@ -32,8 +30,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False, - short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, - readout: Hint[Readout] = "sum"): + short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): super(GraphIsomorphismNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -50,7 +47,14 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml self.layers.append(layers.GraphIsomorphismConv(self.dims[i], self.dims[i + 1], edge_input_dim, layer_hidden_dims, eps, learn_eps, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/neuralfp.py b/torchdrug/models/neuralfp.py index a1f9512..b4e5979 100644 --- a/torchdrug/models/neuralfp.py +++ b/torchdrug/models/neuralfp.py @@ -1,13 +1,11 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torch.nn import functional as F from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.NeuralFP") @@ -31,7 +29,7 @@ class NeuralFingerprint(nn.Module, core.Configurable): """ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(NeuralFingerprint, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -49,7 +47,14 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor batch_norm, activation)) self.linears.append(nn.Linear(self.dims[i + 1], output_dim)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/schnet.py b/torchdrug/models/schnet.py index 7861a61..1a7cf09 100644 --- a/torchdrug/models/schnet.py +++ b/torchdrug/models/schnet.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.SchNet") @@ -31,7 +29,7 @@ class SchNet(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, - batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout: Hint[Readout] = "sum"): + batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout="sum"): super(SchNet, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff, num_gaussian, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """