Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Feb 8, 2022
1 parent 4d74e5d commit 50adc02
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 33 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ matplotlib
tqdm
networkx
ninja
jinja2
jinja2
class-resolver
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"networkx",
"ninja",
"jinja2",
"class-resolver",
],
python_requires=">=3.7,<3.9",
classifiers=[
Expand Down
2 changes: 1 addition & 1 deletion torchdrug/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv",
"NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv",
"DiffPool", "MinCutPool",
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort",
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout",
"ConditionalFlow",
"NodeSampler", "EdgeSampler",
"distribution", "functional",
Expand Down
13 changes: 5 additions & 8 deletions torchdrug/models/chebnet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.ChebNet")
Expand All @@ -25,11 +27,11 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum`` and ``mean``.
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout="sum"):
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
super(ChebyshevConvolutionalNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
Expand All @@ -45,12 +47,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F
self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k,
batch_norm, activation))

if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)
self.readout = readout_resolver.make(readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
Expand Down
2 changes: 2 additions & 0 deletions torchdrug/models/gat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GAT")
Expand Down
15 changes: 4 additions & 11 deletions torchdrug/models/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from torch import nn

from torchdrug import core, layers
from torchdrug.layers import readout_resolver, Readout
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GCN")
Expand Down Expand Up @@ -99,11 +99,11 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout="sum"):
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
super(RelationalGraphConvolutionalNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
Expand All @@ -120,14 +120,7 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim,
batch_norm, activation))

if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)
self.readout = readout_resolver.make(readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
Expand Down
6 changes: 4 additions & 2 deletions torchdrug/models/gin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GIN")
Expand All @@ -26,12 +28,12 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
readout="sum"):
readout: Hint[Readout] = "sum"):
super(GraphIsomorphismNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
Expand Down
6 changes: 4 additions & 2 deletions torchdrug/models/neuralfp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn
from torch.nn import functional as F

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.NeuralFP")
Expand All @@ -25,11 +27,11 @@ class NeuralFingerprint(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout="sum"):
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
super(NeuralFingerprint, self).__init__()

if not isinstance(hidden_dims, Sequence):
Expand Down
3 changes: 3 additions & 0 deletions torchdrug/models/schnet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.SchNet")
Expand All @@ -25,6 +27,7 @@ class SchNet(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True,
Expand Down
14 changes: 6 additions & 8 deletions torchdrug/tasks/pretrain.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import copy

import torch
from class_resolver import Hint
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter_max, scatter_min

from torchdrug import core, tasks, layers
from torchdrug.data import constant
from torchdrug.layers import functional
from torchdrug.layers import functional, readout_resolver, Readout
from torchdrug.core import Registry as R


Expand Down Expand Up @@ -169,9 +170,10 @@ class ContextPrediction(tasks.Task, core.Configurable):
r2 (int, optional): outer radius for context graphs
readout (nn.Module, optional): readout function over context anchor nodes
num_negative (int, optional): number of negative samples per positive sample
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1):
def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1):
super(ContextPrediction, self).__init__()
self.model = model
self.k = k
Expand All @@ -184,12 +186,8 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", n
self.context_model = copy.deepcopy(model)
else:
self.context_model = context_model
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

self.readout = readout_resolver.make(readout)

def substruct_and_context(self, graph):
center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()
Expand Down

0 comments on commit 50adc02

Please sign in to comment.