diff --git a/requirements.txt b/requirements.txt index df9a237..683eb65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,4 @@ matplotlib tqdm networkx ninja -jinja2 -class-resolver +jinja2 \ No newline at end of file diff --git a/setup.py b/setup.py index 21356d3..a924387 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ "networkx", "ninja", "jinja2", - "class-resolver", ], python_requires=">=3.7,<3.9", classifiers=[ diff --git a/torchdrug/layers/__init__.py b/torchdrug/layers/__init__.py index ea18488..22c3b60 100644 --- a/torchdrug/layers/__init__.py +++ b/torchdrug/layers/__init__.py @@ -3,7 +3,7 @@ from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \ NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv from .pool import DiffPool, MinCutPool -from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort, readout_resolver, Readout +from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort from .flow import ConditionalFlow from .sampler import NodeSampler, EdgeSampler from . import distribution, functional @@ -23,7 +23,7 @@ "MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv", "NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv", "DiffPool", "MinCutPool", - "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout", + "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "ConditionalFlow", "NodeSampler", "EdgeSampler", "distribution", "functional", diff --git a/torchdrug/layers/readout.py b/torchdrug/layers/readout.py index 66ed64e..3680d29 100644 --- a/torchdrug/layers/readout.py +++ b/torchdrug/layers/readout.py @@ -1,14 +1,9 @@ import torch from torch import nn from torch_scatter import scatter_mean, scatter_add, scatter_max -from class_resolver import ClassResolver -class Readout(nn.Module): - """A base class for readouts.""" - - -class MeanReadout(Readout): +class MeanReadout(nn.Module): """Mean readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -26,7 +21,7 @@ def forward(self, graph, input): return output -class SumReadout(Readout): +class SumReadout(nn.Module): """Sum readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -44,7 +39,7 @@ def forward(self, graph, input): return output -class MaxReadout(Readout): +class MaxReadout(nn.Module): """Max readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -62,12 +57,6 @@ def forward(self, graph, input): return output -readout_resolver = ClassResolver.from_subclasses( - Readout, - default=SumReadout, -) - - class Softmax(nn.Module): """Softmax operator over graphs with variadic sizes.""" diff --git a/torchdrug/models/chebnet.py b/torchdrug/models/chebnet.py index 521aaef..86d2aef 100644 --- a/torchdrug/models/chebnet.py +++ b/torchdrug/models/chebnet.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.ChebNet") @@ -31,7 +29,7 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(ChebyshevConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gat.py b/torchdrug/models/gat.py index db2b5e7..c22e848 100644 --- a/torchdrug/models/gat.py +++ b/torchdrug/models/gat.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.GAT") @@ -31,7 +29,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, - batch_norm=False, activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): super(GraphAttentionNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega self.layers.append(layers.GraphAttentionConv(self.dims[i], self.dims[i + 1], edge_input_dim, num_head, negative_slope, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gcn.py b/torchdrug/models/gcn.py index 679ae28..352c39b 100644 --- a/torchdrug/models/gcn.py +++ b/torchdrug/models/gcn.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.GCN") @@ -29,7 +27,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(GraphConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -44,7 +42,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, for i in range(len(self.dims) - 1): self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ @@ -103,7 +108,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(RelationalGraphConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -120,7 +125,14 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py index f32e2ca..f0d99de 100644 --- a/torchdrug/models/gin.py +++ b/torchdrug/models/gin.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.GIN") @@ -32,8 +30,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False, - short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, - readout: Hint[Readout] = "sum"): + short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): super(GraphIsomorphismNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -50,7 +47,14 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml self.layers.append(layers.GraphIsomorphismConv(self.dims[i], self.dims[i + 1], edge_input_dim, layer_hidden_dims, eps, learn_eps, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/neuralfp.py b/torchdrug/models/neuralfp.py index a1f9512..b4e5979 100644 --- a/torchdrug/models/neuralfp.py +++ b/torchdrug/models/neuralfp.py @@ -1,13 +1,11 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torch.nn import functional as F from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.NeuralFP") @@ -31,7 +29,7 @@ class NeuralFingerprint(nn.Module, core.Configurable): """ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(NeuralFingerprint, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -49,7 +47,14 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor batch_norm, activation)) self.linears.append(nn.Linear(self.dims[i + 1], output_dim)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/schnet.py b/torchdrug/models/schnet.py index 7861a61..1a7cf09 100644 --- a/torchdrug/models/schnet.py +++ b/torchdrug/models/schnet.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.SchNet") @@ -31,7 +29,7 @@ class SchNet(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, - batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout: Hint[Readout] = "sum"): + batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout="sum"): super(SchNet, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff, num_gaussian, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/tasks/pretrain.py b/torchdrug/tasks/pretrain.py index 8b9460e..a26bc5a 100644 --- a/torchdrug/tasks/pretrain.py +++ b/torchdrug/tasks/pretrain.py @@ -1,14 +1,13 @@ import copy import torch -from class_resolver import Hint from torch import nn from torch.nn import functional as F from torch_scatter import scatter_max, scatter_min from torchdrug import core, tasks, layers from torchdrug.data import constant -from torchdrug.layers import functional, readout_resolver, Readout +from torchdrug.layers import functional from torchdrug.core import Registry as R @@ -173,7 +172,7 @@ class ContextPrediction(tasks.Task, core.Configurable): readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. """ - def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1): + def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1): super(ContextPrediction, self).__init__() self.model = model self.k = k @@ -187,7 +186,14 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Rea else: self.context_model = context_model - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def substruct_and_context(self, graph): center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()