Skip to content

Commit

Permalink
add some documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
KiddoZhu committed Jun 4, 2022
1 parent 957cec7 commit ce75ce4
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 34 deletions.
Binary file added asset/graph/correct_reference.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asset/graph/inverse_edge.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asset/graph/wrong_reference.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions doc/source/api/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ Variadic

.. autofunction:: variadic_sample

.. autofunction:: variadic_meshgrid

.. autofunction:: variadic_to_padded

.. autofunction:: padded_to_variadic

Tensor Reduction
^^^^^^^^^^^^^^^^
.. autofunction:: masked_mean
Expand Down
18 changes: 15 additions & 3 deletions doc/source/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,21 @@ R2
^^
.. autofunction:: r2

Variadic Accuracy
^^^^^^^^^^^^^^^^^
.. autofunction:: variadic_accuracy
Accuracy
^^^^^^^^
.. autofunction:: accuracy

Matthews Correlation Coefficient
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofuction:: matthews_corrcoef

Pearson Correlation Coefficient
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pearsonr

Spearman Correlation Coefficient
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: spearmanr


Chemical Metrics
Expand Down
33 changes: 2 additions & 31 deletions doc/source/notes/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ the result is not desired. The edges are masked out correctly, but the values of
inverse indexes are wrong.

.. code:: python
with graph.edge():
graph.inv_edge_index = torch.tensor(inv_edge_index)
g1 = graph.edge_mask([0, 2, 3])
Expand All @@ -55,34 +56,4 @@ since the corresponding inverse edge has been masked out.
:width: 33%

We can use ``graph.node_reference()`` and ``graph.graph_reference()`` for references
to nodes and graphs respectively.

Use Cases in Proteins
---------------------

In :class:`data.Protein`, the mapping ``atom2residue`` is implemented as
references. The intuition is that references enable flexible indexing on either atoms
or residues, while maintaining the correspondence between two views.

The following example shows how to track a specific residue with ``atom2residue`` in
the atom view. For a protein, we first create a mask for atoms in a glutamine (GLN).

.. code:: python
protein = data.Protein.from_sequence("KALKQMLDMG")
is_glutamine = protein.residue_type[protein.atom2residue] == protein.residue2id["GLN"]
with protein.node():
protein.is_glutamine = is_glutamine
We then apply a mask to the protein residue sequence. In the output protein,
``atom2residue`` is able to map the masked atoms back to the glutamine residue.

.. code:: python
p1 = protein[3:6]
residue_type = p1.residue_type[p1.atom2residue[p1.is_glutamine]]
print([p1.id2residue[r] for r in residue_type.tolist()])
.. code:: bash
['GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN']
to nodes and graphs respectively.
1 change: 1 addition & 0 deletions doc/source/notes/variadic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ Naturally, the prediction over nodes also forms a variadic tensor with ``num_nod
:func:`variadic_topk <torchdrug.layers.functional.variadic_topk>`,
:func:`variadic_randperm <torchdrug.layers.functional.variadic_randperm>`,
:func:`variadic_sample <torchdrug.layers.functional.variadic_sample>`,
:func:`variadic_meshgrid <torchdrug.layers.functional.variadic_meshgrid`,
:func:`variadic_softmax <torchdrug.layers.functional.variadic_softmax>`,
:func:`variadic_log_softmax <torchdrug.layers.functional.variadic_log_softmax>`,
:func:`variadic_cross_entropy <torchdrug.layers.functional.variadic_cross_entropy>`,
Expand Down
22 changes: 22 additions & 0 deletions torchdrug/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,17 @@ def edge_mask(self, index):
num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)

def line_graph(self):
"""
Construct a line graph of this graph.
The node feature of the line graph is inherited from the edge feature of the original graph.
In the line graph, each node corresponds to an edge in the original graph.
For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph,
there is a directed edge (a, b) -> (b, c) in the line graph.
Returns:
Graph
"""
node_in, node_out = self.edge_list.t()[:2]
edge_index = torch.arange(self.num_edge, device=self.device)
edge_in = edge_index[node_out.argsort()]
Expand Down Expand Up @@ -1627,6 +1638,17 @@ def subbatch(self, index):
return self.graph_mask(index, compact=True)

def line_graph(self):
"""
Construct a packed line graph of this packed graph.
The node features of the line graphs are inherited from the edge features of the original graphs.
In the line graph, each node corresponds to an edge in the original graph.
For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph,
there is a directed edge (a, b) -> (b, c) in the line graph.
Returns:
PackedGraph
"""
node_in, node_out = self.edge_list.t()[:2]
edge_index = torch.arange(self.num_edge, device=self.device)
edge_in = edge_index[node_out.argsort()]
Expand Down
38 changes: 38 additions & 0 deletions torchdrug/layers/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ def variadic_sort(input, size, descending=False):
input (Tensor): input of shape :math:`(B, ...)`
size (LongTensor): size of sets of shape :math:`(N,)`
descending (bool, optional): return ascending or descending order
Returns
(Tensor, LongTensor): sorted values and indexes
"""
index2sample = _size_to_index(size)
index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
Expand Down Expand Up @@ -445,6 +448,21 @@ def variadic_sample(input, size, num_sample):


def variadic_meshgrid(input1, size1, input2, size2):
"""
Compute the Cartesian product for two batches of sets with variadic sizes.
Suppose there are :math:`N` sets in each input,
and the sizes of all sets are summed to :math:`B_1` and :math:`B_2` respectively.
Parameters:
input1 (Tensor): input of shape :math:`(B_1, ...)`
size1 (LongTensor): size of :attr:`input1` of shape :math:`(N,)`
input2 (Tensor): input of shape :math:`(B_2, ...)`
size2 (LongTensor): size of :attr:`input2` of shape :math:`(N,)`
Returns
(Tensor, Tensor): the first and the second elements in the Cartesian product
"""
grid_size = size1 * size2
local_index = variadic_arange(grid_size)
local_inner_size = size2.repeat_interleave(grid_size)
Expand All @@ -456,6 +474,19 @@ def variadic_meshgrid(input1, size1, input2, size2):


def variadic_to_padded(input, size, value=0):
"""
Convert a variadic tensor to a padded tensor.
Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
Parameters:
input (Tensor): input of shape :math:`(B, ...)`
size (LongTensor): size of sets of shape :math:`(N,)`
value (scalar): fill value for padding
Returns:
(Tensor, BoolTensor): padded tensor and mask
"""
num_sample = len(size)
max_size = size.max()
starts = torch.arange(num_sample, device=size.device) * max_size
Expand All @@ -469,6 +500,13 @@ def variadic_to_padded(input, size, value=0):


def padded_to_variadic(padded, size):
"""
Convert a padded tensor to a variadic tensor.
Parameters:
padded (Tensor): padded tensor of shape :math:`(N, ...)`
size (LongTensor): size of sets of shape :math:`(N,)`
"""
num_sample, max_size = padded.shape[:2]
starts = torch.arange(num_sample, device=size.device) * max_size
ends = starts + size
Expand Down
1 change: 1 addition & 0 deletions torchdrug/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class TargetNormalize(object):
"""
Normalize the target values in a sample.
Parameters:
mean (dict of float): mean of targets
std (dict of float): standard deviation of targets
Expand Down

0 comments on commit ce75ce4

Please sign in to comment.