diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index d817d00d..5b94bbe4 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -30,6 +30,9 @@ jobs:
python -m pip install flake8 pytest-cov
if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi
if [ "${{ matrix.python-version }}" != "3.10" ]; then pip install mindspore==1.10.0; fi
+ - name: preparation for tests
+ run: |
+ python tests/test_a_star/prepare_for_test.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -64,6 +67,9 @@ jobs:
python -m pip install --upgrade pip
python -m pip install flake8 pytest-cov
pip install -r tests/requirements_win_mac.txt
+ - name: preparation for tests
+ run: |
+ python tests/test_a_star/prepare_for_test.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -74,7 +80,6 @@ jobs:
run: |
pytest --cov=pygmtools --cov-report=xml --backend=mindspore tests/test_classic_solvers.py
pytest --cov=pygmtools --cov-report=xml --cov-append
-
windows:
runs-on: windows-latest
@@ -94,6 +99,9 @@ jobs:
python -m pip install --upgrade pip
python -m pip install flake8 pytest-cov
python -m pip install -r tests\requirements_win_mac.txt
+ - name: preparation for tests
+ run: |
+ python tests/test_a_star/prepare_for_test.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -103,4 +111,4 @@ jobs:
- name: Test with pytest. They are divided into two runs because MindSpore will interfer with Paddle.
run: |
pytest --cov=pygmtools --cov-report=xml --backend=mindspore tests/test_classic_solvers.py
- pytest --cov=pygmtools --cov-report=xml --cov-append
+ pytest --cov=pygmtools --cov-report=xml --cov-append
\ No newline at end of file
diff --git a/docs/images/astar.png b/docs/images/astar.png
new file mode 100644
index 00000000..b21c5a64
Binary files /dev/null and b/docs/images/astar.png differ
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 5c30f6f0..cdf54f14 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -15,4 +15,5 @@ matplotlib
networkx
scikit-learn
wget
+networkx==2.8.8
pygmtools
diff --git a/examples/data/graph1.graphml b/examples/data/graph1.graphml
new file mode 100644
index 00000000..fa9bdbbb
--- /dev/null
+++ b/examples/data/graph1.graphml
@@ -0,0 +1,61 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/data/graph2.graphml b/examples/data/graph2.graphml
new file mode 100644
index 00000000..228ec85a
--- /dev/null
+++ b/examples/data/graph2.graphml
@@ -0,0 +1,67 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/pygmtools/__init__.py b/pygmtools/__init__.py
index 6899fbfc..b2b3534a 100644
--- a/pygmtools/__init__.py
+++ b/pygmtools/__init__.py
@@ -10,9 +10,9 @@
from .benchmark import Benchmark
from .linear_solvers import sinkhorn, hungarian
-from .classic_solvers import rrwm, sm, ipfp
+from .classic_solvers import rrwm, sm, ipfp, astar
from .multi_graph_solvers import cao, mgm_floyd, gamgm
-from .neural_solvers import pca_gm, ipca_gm, cie, ngm
+from .neural_solvers import pca_gm, ipca_gm, cie, ngm, genn_astar
import pygmtools.utils as utils
BACKEND = 'numpy'
__version__ = '0.3.8'
diff --git a/pygmtools/astar/a_star.tar.gz b/pygmtools/astar/a_star.tar.gz
new file mode 100644
index 00000000..a63c1c76
Binary files /dev/null and b/pygmtools/astar/a_star.tar.gz differ
diff --git a/pygmtools/classic_solvers.py b/pygmtools/classic_solvers.py
index 914a514b..d42fc9e9 100644
--- a/pygmtools/classic_solvers.py
+++ b/pygmtools/classic_solvers.py
@@ -22,8 +22,9 @@
import importlib
import pygmtools
-from pygmtools.utils import NOT_IMPLEMENTED_MSG, _check_shape, _get_shape, _unsqueeze, _squeeze, _check_data_type
-
+from pygmtools.utils import NOT_IMPLEMENTED_MSG, _check_shape, _get_shape,\
+ _unsqueeze, _squeeze, _check_data_type,from_numpy
+import numpy as np
def sm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
max_iter: int=50,
@@ -1061,6 +1062,117 @@ def ipfp(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
return result
def astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, dropout=0, beam_width=0,
          trust_fact=1, no_pred_size=0, backend=None):
    r"""
    ASTAR solver for graph matching (Lawler's QAP).
    The **ASTAR** solver finds the optimal match between two graphs through heuristic search.

    :param feat1: :math:`(b\times n_1 \times d)` input feature of graph1
    :param feat2: :math:`(b\times n_2 \times d)` input feature of graph2
    :param A1: :math:`(b\times n_1 \times n_1)` input adjacency matrix of graph1
    :param A2: :math:`(b\times n_2 \times n_2)` input adjacency matrix of graph2
    :param n1: :math:`(b)` number of nodes in graph1. Optional if all equal to :math:`n_1`
    :param n2: :math:`(b)` number of nodes in graph2. Optional if all equal to :math:`n_2`
    :param channel: (default: None) Channel size of the input layer. If given, it must match the feature dimension (d) of feat1, feat2.
        If not given, it will be defined by the feature dimension (d) of feat1, feat2.
    :param dropout: (default: 0) Dropout probability
    :param beam_width: (default: 0) Size of beam-search width (0 = no beam).
    :param trust_fact: (default: 1) The trust factor on GNN prediction (0 = no GNN).
    :param no_pred_size: (default: 0) If the smaller graph has no more than x nodes, stop using heuristics.
    :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
    :return: :math:`(b\times n_1 \times n_2)` the doubly-stochastic matching matrix
    :raises ValueError: if feat1, feat2, A1, A2 are not all 2-dimensional or all 3-dimensional, or if their
        batch / node / feature dimensions do not agree with each other
    :raises NotImplementedError: if the selected backend module cannot be imported or does not provide ``astar``

    .. note::
        This function also supports non-batched input, by ignoring all batch dimensions in the input tensors.

    .. dropdown:: PyTorch Example

        ::

            >>> import torch
            >>> import pygmtools as pygm
            >>> pygm.BACKEND = 'pytorch'
            >>> _ = torch.manual_seed(1)

            # Generate a batch of isomorphic graphs
            >>> batch_size = 10
            >>> nodes_num = 4
            >>> channel = 36

            >>> X_gt = torch.zeros(batch_size, nodes_num, nodes_num)
            >>> X_gt[:, torch.arange(0, nodes_num, dtype=torch.int64), torch.randperm(nodes_num)] = 1
            >>> A1 = 1. * (torch.rand(batch_size, nodes_num, nodes_num) > 0.5)
            >>> torch.diagonal(A1, dim1=1, dim2=2)[:] = 0 # discard self-loop edges
            >>> A2 = torch.bmm(torch.bmm(X_gt.transpose(1, 2), A1), X_gt)
            >>> feat1 = torch.rand(batch_size, nodes_num, channel) - 0.5
            >>> feat2 = torch.bmm(X_gt.transpose(1, 2), feat1)
            >>> n1 = n2 = torch.tensor([nodes_num] * batch_size)

            # Match by ASTAR
            >>> X = pygm.astar(feat1, feat2, A1, A2, n1, n2)
            >>> (X * X_gt).sum() / X_gt.sum()# accuracy
            tensor(1.)

            # This function also supports non-batched input, by ignoring all batch dimensions in the input tensors.
            >>> part_f1 = feat1[0]
            >>> part_f2 = feat2[0]
            >>> part_A1 = A1[0]
            >>> part_A2 = A2[0]
            >>> part_X_gt = X_gt[0]
            >>> part_X = pygm.astar(part_f1, part_f2, part_A1, part_A2)

            >>> part_X.shape
            torch.Size([4, 4])

            >>> (part_X * part_X_gt).sum() / part_X_gt.sum()# accuracy
            tensor(1.)
    """
    if backend is None:
        backend = pygmtools.BACKEND
    non_batched_input = False
    if feat1 is not None:  # when feat1 is None all validation is skipped and the inputs go to the backend as-is
        # Pass the variable names so that a type mismatch reports which argument is wrong.
        for var_name, var in (('feat1', feat1), ('feat2', feat2), ('A1', A1), ('A2', A2)):
            _check_data_type(var, var_name, backend)

        inputs = (feat1, feat2, A1, A2)
        if all(_check_shape(var, 2, backend) for var in inputs):
            # Non-batched input: add a leading batch dimension of size 1 and remember to drop it on return.
            feat1, feat2, A1, A2 = [_unsqueeze(var, 0, backend) for var in inputs]
            # Plain integer node counts are wrapped into a length-1 backend tensor (bool is deliberately excluded).
            if isinstance(n1, (int, np.integer)) and not isinstance(n1, bool):
                n1 = from_numpy(np.array([n1]), backend=backend)
            if isinstance(n2, (int, np.integer)) and not isinstance(n2, bool):
                n2 = from_numpy(np.array([n2]), backend=backend)
            non_batched_input = True
        elif not all(_check_shape(var, 3, backend) for var in inputs):
            raise ValueError(
                f'the input arguments feat1, feat2, A1, A2 are expected to be all 2-dimensional or 3-dimensional, got '
                f'feat1:{len(_get_shape(feat1, backend))}dims, feat2:{len(_get_shape(feat2, backend))}dims, '
                f'A1:{len(_get_shape(A1, backend))}dims, A2:{len(_get_shape(A2, backend))}dims!')

        # Query every shape once; all four tensors are 3-dimensional from here on.
        shape_f1, shape_f2, shape_A1, shape_A2 = [_get_shape(var, backend) for var in (feat1, feat2, A1, A2)]
        same_batch = shape_f1[0] == shape_f2[0] == shape_A1[0] == shape_A2[0]
        graph1_consistent = shape_f1[1] == shape_A1[1] == shape_A1[2]
        graph2_consistent = shape_f2[1] == shape_A2[1] == shape_A2[2]
        same_feature_dim = shape_f1[2] == shape_f2[2]
        if not (same_batch and graph1_consistent and graph2_consistent and same_feature_dim):
            raise ValueError(
                f'the input dimensions do not match. Got feat1:{shape_f1}, '
                f'feat2:{shape_f2}, A1:{shape_A1}, A2:{shape_A2}!')
        if n1 is not None:
            _check_data_type(n1, 'n1', backend)
        if n2 is not None:
            _check_data_type(n2, 'n2', backend)

    args = (feat1, feat2, A1, A2, n1, n2, channel, dropout, beam_width, trust_fact, no_pred_size)
    try:
        mod = importlib.import_module(f'pygmtools.{backend}_backend')
        fn = mod.astar
    except (ModuleNotFoundError, AttributeError) as err:
        # Keep the original exception as the cause so a missing dependency is still visible in the traceback.
        raise NotImplementedError(
            NOT_IMPLEMENTED_MSG.format(backend)
        ) from err

    result = fn(*args)
    match_mat = _squeeze(result[0], 0, backend) if non_batched_input else result[0]
    return match_mat
+
+
def __check_gm_arguments(n1, n2, n1max, n2max):
if n1 is None and n1max is None:
raise ValueError('at least one of the following arguments are required: n1 and n1max.')
diff --git a/pygmtools/neural_solvers.py b/pygmtools/neural_solvers.py
index fa40bdad..5be71048 100644
--- a/pygmtools/neural_solvers.py
+++ b/pygmtools/neural_solvers.py
@@ -1271,3 +1271,211 @@ def ngm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None,
return match_mat, result[1]
else:
return match_mat
+
+
def genn_astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, filters_1=64, filters_2=32, filters_3=16,
               tensor_neurons=16, dropout=0, beam_width=0, trust_fact=1, no_pred_size=0,
               network=None, return_network=False, pretrain='AIDS700nef', backend=None):
    r"""
    The **GENN-ASTAR** (Graph Edit Neural Network Astar) solver for graph matching based on the combination of traditional A-star and Neural Network.
    This algorithm replaces the heuristic prediction module in the traditional A-star algorithm with **GNN** (Graph Neural Network) model,
    greatly improving the efficiency of A-star algorithm while ensuring a certain degree of accuracy.
    During the search process, the algorithm prioritizes the next search direction based on the distance between the current state and the target state.
    At each step of the search, the algorithm uses a predicted probability distribution of node pairs for matching.

    See the following picture to better understand the workflow of the algorithm:

    .. image:: ../../images/astar.png

    See the following paper for more technical details:
    "Combinatorial Learning of Graph Edit Distance via Dynamic Embedding" (CVPR 2021, full citation below).


    :param feat1: :math:`(b\times n_1 \times d)` input feature of graph1
    :param feat2: :math:`(b\times n_2 \times d)` input feature of graph2
    :param A1: :math:`(b\times n_1 \times n_1)` input adjacency matrix of graph1
    :param A2: :math:`(b\times n_2 \times n_2)` input adjacency matrix of graph2
    :param n1: :math:`(b)` number of nodes in graph1. Optional if all equal to :math:`n_1`
    :param n2: :math:`(b)` number of nodes in graph2. Optional if all equal to :math:`n_2`
    :param channel: (default: None) Channel size of the input layer. If given, it must match the feature dimension (d) of feat1, feat2.
        If not given, it will be defined by the feature dimension (d) of feat1, feat2.
        Ignored if the network object is given (ignored if network!=None)
    :param filters_1: (default: 64) Filters (neurons) in 1st convolution.
    :param filters_2: (default: 32) Filters (neurons) in 2nd convolution.
    :param filters_3: (default: 16) Filters (neurons) in 3rd convolution.
    :param tensor_neurons: (default: 16) Neurons in tensor network layer.
    :param dropout: (default: 0) Dropout probability
    :param beam_width: (default: 0) Size of beam-search width (0 = no beam).
    :param trust_fact: (default: 1) The trust factor on GNN prediction (0 = no GNN).
    :param no_pred_size: (default: 0) If the smaller graph has no more than x nodes, stop using heuristics.
    :param network: (default: None) The network object. If None, a new network object will be created, and load the
        model weights specified in ``pretrain`` argument.
    :param return_network: (default: False) Return the network object (saving model construction time if calling the
        model multiple times).
    :param pretrain: (default: 'AIDS700nef') If ``network==None``, the pretrained model weights to be loaded. Available
        pretrained weights: ``AIDS700nef`` (channel=36), ``LINUX`` (channel=8),
        or ``False`` (no pretraining).
    :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
    :return: if ``return_network==False``, :math:`(b\times n_1 \times n_2)` the doubly-stochastic matching matrix

        if ``return_network==True``, :math:`(b\times n_1 \times n_2)` the doubly-stochastic matching matrix,
        the network object
    :raises ValueError: if feat1, feat2, A1, A2 are not all 2-dimensional or all 3-dimensional, or if their
        batch / node / feature dimensions do not agree with each other
    :raises NotImplementedError: if the selected backend module cannot be imported or does not provide ``genn_astar``

    .. note::
        You may need a proxy to load the pretrained weights if Google drive is not accessible in your country/region.
        You may also download the pretrained models manually (they are hosted on Google drive) and put them at
        ``~/.cache/pygmtools`` (for Linux).

    .. note::
        This function also supports non-batched input, by ignoring all batch dimensions in the input tensors.

    .. dropdown:: PyTorch Example

        ::

            >>> import torch
            >>> import pygmtools as pygm
            >>> pygm.BACKEND = 'pytorch'
            >>> _ = torch.manual_seed(1)

            # Generate a batch of isomorphic graphs
            >>> batch_size = 10
            >>> nodes_num = 4
            >>> channel = 36

            >>> X_gt = torch.zeros(batch_size, nodes_num, nodes_num)
            >>> X_gt[:, torch.arange(0, nodes_num, dtype=torch.int64), torch.randperm(nodes_num)] = 1
            >>> A1 = 1. * (torch.rand(batch_size, nodes_num, nodes_num) > 0.5)
            >>> torch.diagonal(A1, dim1=1, dim2=2)[:] = 0 # discard self-loop edges
            >>> A2 = torch.bmm(torch.bmm(X_gt.transpose(1, 2), A1), X_gt)
            >>> feat1 = torch.rand(batch_size, nodes_num, channel) - 0.5
            >>> feat2 = torch.bmm(X_gt.transpose(1, 2), feat1)
            >>> n1 = n2 = torch.tensor([nodes_num] * batch_size)

            # Match by GENN-ASTAR (load pretrained model)
            >>> X, net = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, return_network=True)
            Downloading to ~/.cache/pygmtools/best_genn_AIDS700nef_gcn_astar.pt...
            >>> (X * X_gt).sum() / X_gt.sum()# accuracy
            tensor(1.)

            # Pass the net object to avoid rebuilding the model again
            >>> X = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, network=net)

            # This function also supports non-batched input, by ignoring all batch dimensions in the input tensors.
            >>> part_f1 = feat1[0]
            >>> part_f2 = feat2[0]
            >>> part_A1 = A1[0]
            >>> part_A2 = A2[0]
            >>> part_X_gt = X_gt[0]
            >>> part_X = pygm.genn_astar(part_f1, part_f2, part_A1, part_A2, return_network=False)

            >>> part_X.shape
            torch.Size([4, 4])

            >>> (part_X * part_X_gt).sum() / part_X_gt.sum()# accuracy
            tensor(1.)

            # You may also load other pretrained weights
            # However, it should be noted that each pretrained set supports different node feature dimensions
            # AIDS700nef(Default): channel = 36
            # LINUX: channel = 8
            # Generate a batch of isomorphic graphs
            >>> batch_size = 10
            >>> nodes_num = 4
            >>> channel = 8

            >>> X_gt = torch.zeros(batch_size, nodes_num, nodes_num)
            >>> X_gt[:, torch.arange(0, nodes_num, dtype=torch.int64), torch.randperm(nodes_num)] = 1
            >>> A1 = 1. * (torch.rand(batch_size, nodes_num, nodes_num) > 0.5)
            >>> torch.diagonal(A1, dim1=1, dim2=2)[:] = 0 # discard self-loop edges
            >>> A2 = torch.bmm(torch.bmm(X_gt.transpose(1, 2), A1), X_gt)
            >>> feat1 = torch.rand(batch_size, nodes_num, channel) - 0.5
            >>> feat2 = torch.bmm(X_gt.transpose(1, 2), feat1)
            >>> n1 = n2 = torch.tensor([nodes_num] * batch_size)

            >>> X, net = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, pretrain='LINUX', return_network=True)
            Downloading to ~/.cache/pygmtools/best_genn_LINUX_gcn_astar.pt...

            >>> (X * X_gt).sum() / X_gt.sum()# accuracy
            tensor(1.)

            # When the input node feature dimension is different from the one supported by pre training,
            # you can still use the solver, but the solver will provide a warning
            >>> X, net = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, return_network=True, pretrain='AIDS700nef')
            Warning: Skip key(s) in state_dict: "convolution_1.weight".
            Warning: Pretrain AIDS700nef does not support the parameters you entered,
            Supported parameters: ( channel:36, filters:(64,32,16) ),
            Input parameters: ( channel:8, filters:(64,32,16) )

            # You may configure your own model and integrate the model into a deep learning pipeline. For example:
            >>> net = pygm.utils.get_network(pygm.genn_astar, channel = 1000, filters_1 = 1024, filters_2 = 256, filters_3 = 128, pretrain=False)
            >>> optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
            # feat1/feat2 may be outputs by other neural networks
            >>> X = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, network=net)
            >>> loss = pygm.utils.permutation_loss(X, X_gt)
            >>> loss.backward()
            >>> optimizer.step()

    .. note::

        If you find this model useful in your research, please cite:

        ::

            @inproceedings{WangCVPR21,
              author={Runzhong Wang, Tianqi Zhang, Tianshu Yu, Junchi Yan, Xiaokang Yang},
              booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
              title={Combinatorial Learning of Graph Edit Distance via Dynamic Embedding},
              year={2021},
            }
    """

    if backend is None:
        backend = pygmtools.BACKEND
    non_batched_input = False
    if feat1 is not None:  # if feat1 is None, this function skips the forward pass and only returns a network object
        # Pass the variable names so that a type mismatch reports which argument is wrong.
        for var_name, var in (('feat1', feat1), ('feat2', feat2), ('A1', A1), ('A2', A2)):
            _check_data_type(var, var_name, backend)

        inputs = (feat1, feat2, A1, A2)
        if all(_check_shape(var, 2, backend) for var in inputs):
            # Non-batched input: add a leading batch dimension of size 1 and remember to drop it on return.
            feat1, feat2, A1, A2 = [_unsqueeze(var, 0, backend) for var in inputs]
            # Plain integer node counts are wrapped into a length-1 backend tensor (bool is deliberately excluded).
            if isinstance(n1, (int, np.integer)) and not isinstance(n1, bool):
                n1 = from_numpy(np.array([n1]), backend=backend)
            if isinstance(n2, (int, np.integer)) and not isinstance(n2, bool):
                n2 = from_numpy(np.array([n2]), backend=backend)
            non_batched_input = True
        elif not all(_check_shape(var, 3, backend) for var in inputs):
            raise ValueError(
                f'the input arguments feat1, feat2, A1, A2 are expected to be all 2-dimensional or 3-dimensional, got '
                f'feat1:{len(_get_shape(feat1, backend))}dims, feat2:{len(_get_shape(feat2, backend))}dims, '
                f'A1:{len(_get_shape(A1, backend))}dims, A2:{len(_get_shape(A2, backend))}dims!')

        # Query every shape once; all four tensors are 3-dimensional from here on.
        shape_f1, shape_f2, shape_A1, shape_A2 = [_get_shape(var, backend) for var in (feat1, feat2, A1, A2)]
        same_batch = shape_f1[0] == shape_f2[0] == shape_A1[0] == shape_A2[0]
        graph1_consistent = shape_f1[1] == shape_A1[1] == shape_A1[2]
        graph2_consistent = shape_f2[1] == shape_A2[1] == shape_A2[2]
        same_feature_dim = shape_f1[2] == shape_f2[2]
        if not (same_batch and graph1_consistent and graph2_consistent and same_feature_dim):
            raise ValueError(
                f'the input dimensions do not match. Got feat1:{shape_f1}, '
                f'feat2:{shape_f2}, A1:{shape_A1}, A2:{shape_A2}!')
        if n1 is not None:
            _check_data_type(n1, 'n1', backend)
        if n2 is not None:
            _check_data_type(n2, 'n2', backend)

    args = (feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3,
            tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain)
    try:
        mod = importlib.import_module(f'pygmtools.{backend}_backend')
        fn = mod.genn_astar
    except (ModuleNotFoundError, AttributeError) as err:
        # Keep the original exception as the cause so a missing dependency is still visible in the traceback.
        raise NotImplementedError(
            NOT_IMPLEMENTED_MSG.format(backend)
        ) from err

    result = fn(*args)
    match_mat = _squeeze(result[0], 0, backend) if non_batched_input else result[0]
    if return_network:
        return match_mat, result[1]
    return match_mat
+
\ No newline at end of file
diff --git a/pygmtools/pytorch_astar_modules.py b/pygmtools/pytorch_astar_modules.py
new file mode 100644
index 00000000..1c2c8540
--- /dev/null
+++ b/pygmtools/pytorch_astar_modules.py
@@ -0,0 +1,381 @@
+import torch
+import pygmtools.utils
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+from torch import Tensor
+from typing import Optional, Tuple
+
+VERY_LARGE_INT = 65536
+
+###############################################################
+# GENN-A* Functions #
+###############################################################
+
+
def default_parameter():
    """
    Return a freshly built dictionary holding the default GENN-A* settings.

    A new dictionary is created on every call, so callers may modify the result freely.

    :return: dict with the keys ``cuda``, ``pretrain``, ``channel``, ``filters_1``, ``filters_2``,
        ``filters_3``, ``tensor_neurons``, ``dropout``, ``astar_beam_width``, ``astar_trust_fact``,
        ``astar_no_pred`` and ``use_net``
    """
    return {
        'cuda': False,
        'pretrain': False,
        'channel': 36,
        'filters_1': 64,
        'filters_2': 32,
        'filters_3': 16,
        'tensor_neurons': 16,
        'dropout': 0,
        'astar_beam_width': 0,
        'astar_trust_fact': 1,
        'astar_no_pred': 0,
        'use_net': True,
    }
+
+
def check_layer_parameter(params):
    """
    Tell whether the layer sizes in ``params`` equal the fixed values this function checks for.

    The channel size is only checked when ``params['pretrain']`` is ``'AIDS700nef'`` (expects 36) or
    ``'LINUX'`` (expects 8); the remaining sizes are always checked.

    :param params: dict providing ``pretrain``, ``filters_1``, ``filters_2``, ``filters_3`` and
        ``tensor_neurons`` (plus ``channel`` for the two pretrain names above)
    :return: ``True`` if every checked entry has the expected value, otherwise ``False``
    """
    pretrain = params['pretrain']
    if pretrain == 'AIDS700nef':
        channel_ok = not (params['channel'] != 36)
    elif pretrain == 'LINUX':
        channel_ok = not (params['channel'] != 8)
    else:
        channel_ok = True
    if not channel_ok:
        return False

    # Same order as the sequential checks: first mismatch decides the outcome.
    fixed_sizes = (('filters_1', 64), ('filters_2', 32), ('filters_3', 16), ('tensor_neurons', 16))
    for key, expected in fixed_sizes:
        if params[key] != expected:
            return False
    return True
+
+
def node_metric(node1, node2):
    """
    Build a 0/1 indicator of which node pairs have differing feature vectors.

    The L1 distance between every node of ``node1`` and every node of ``node2`` is computed by
    broadcasting over the dimensions inserted at positions 2 and 1; every non-zero distance is then
    replaced by 1 so that the result only tells "identical" (0) from "different" (1).

    :param node1: node features whose node axis is dimension 1 and feature axis is the last dimension
    :param node2: node features laid out like ``node1``
    :return: tensor with the feature axis reduced away, 0 where the two feature vectors are equal
        and 1 elsewhere
    """
    encoding = torch.sum(torch.abs(node1.unsqueeze(2) - node2.unsqueeze(1)), dim=-1)
    # One masked assignment instead of a Python loop over every non-zero entry.
    encoding[encoding != 0] = 1
    return encoding
+
+
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """
    Expand ``src`` to the size of ``other``.

    A 1-dimensional ``src`` is first placed on axis ``dim`` by adding leading singleton axes; any
    axes still missing compared with ``other`` are appended at the end before expanding.

    :param src: tensor to expand
    :param other: tensor whose size is the target size
    :param dim: axis of ``other`` that a 1-dimensional ``src`` runs along (negative values count from the end)
    :return: ``src`` expanded to ``other.size()``
    """
    axis = other.dim() + dim if dim < 0 else dim
    if src.dim() == 1 and axis > 0:
        # Indexing with ``None`` inserts singleton axes in front, same as repeated ``unsqueeze(0)``.
        src = src[(None,) * axis]
    trailing = other.dim() - src.dim()
    if trailing > 0:
        src = src[(Ellipsis,) + (None,) * trailing]
    return src.expand(other.size())
+
+
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """
    Sum the entries of ``src`` into the positions given by ``index`` along ``dim``.

    :param src: values to accumulate
    :param index: target positions; it is expanded to the size of ``src`` with ``broadcast``
    :param dim: axis along which to scatter
    :param out: optional tensor to accumulate into (modified in place and returned)
    :param dim_size: size of the result along ``dim`` when ``out`` is not given; if omitted it is
        ``index.max() + 1``, or 0 for an empty index
    :return: the accumulated tensor
    """
    index = broadcast(index, src, dim)
    if out is not None:
        return out.scatter_add_(dim, index, src)

    out_shape = list(src.size())
    if dim_size is not None:
        out_shape[dim] = dim_size
    elif index.numel() == 0:
        out_shape[dim] = 0
    else:
        out_shape[dim] = int(index.max()) + 1
    accumulator = torch.zeros(out_shape, dtype=src.dtype, device=src.device)
    return accumulator.scatter_add_(dim, index, src)
+
+
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """
    Average the entries of ``src`` that share the same ``index`` along ``dim``.

    The sums come from ``scatter_sum``; they are divided in place by the number of contributions
    per position, with empty positions counted as 1 so that they are not divided by zero.

    :param src: values to average
    :param index: target positions
    :param dim: axis along which to scatter
    :param out: optional tensor passed on to ``scatter_sum`` as the accumulator
    :param dim_size: size of the result along ``dim`` when ``out`` is not given
    :return: tensor of per-position means (floor division is used for non-floating-point results)
    """
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    # Axis of ``index`` along which the contributions are counted.
    count_axis = dim + src.dim() if dim < 0 else dim
    if index.dim() <= count_axis:
        count_axis = index.dim() - 1

    unit = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(unit, index, count_axis, None, dim_size)
    count[count < 1] = 1
    count = broadcast(count, out, dim)

    if out.is_floating_point():
        out.true_divide_(count)
    else:
        out.div_(count, rounding_mode='floor')
    return out
+
+
def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,
                   fill_value: float = 0., max_num_nodes: Optional[int] = None,
                   batch_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
    """
    Turn a stacked node tensor into a padded per-graph tensor plus a validity mask.

    :param x: node tensor whose first dimension enumerates the nodes of all graphs
    :param batch: graph id of every node; if omitted all nodes are placed in graph 0
    :param fill_value: value written into the padding slots
    :param max_num_nodes: number of slots per graph; defaults to the largest node count
    :param batch_size: number of graphs; defaults to ``batch.max() + 1``
    :return: ``(out, mask)`` where ``out`` has size ``[batch_size, max_num_nodes, ...]`` and ``mask`` is a
        boolean ``[batch_size, max_num_nodes]`` tensor that is ``True`` at the slots filled from ``x``
    """
    if batch is None and max_num_nodes is None:
        # Single graph without padding: only a leading batch axis is needed.
        all_valid = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
        return x.unsqueeze(0), all_valid

    if batch is None:
        batch = x.new_zeros(x.size(0), dtype=torch.long)
    if batch_size is None:
        batch_size = int(batch.max()) + 1

    num_nodes = scatter_sum(batch.new_ones(x.size(0)), batch, dim=0,
                            dim_size=batch_size)
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max())

    # Flat slot of each node: position inside its own graph plus the offset of that graph.
    position = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    flat_idx = (position - cum_nodes[batch]) + (batch * max_num_nodes)

    feature_dims = list(x.size())[1:]
    total_slots = batch_size * max_num_nodes
    out = x.new_full([total_slots] + feature_dims, fill_value)
    out[flat_idx] = x
    out = out.view([batch_size, max_num_nodes] + feature_dims)

    mask = torch.zeros(total_slots, dtype=torch.bool, device=x.device)
    mask[flat_idx] = 1
    return out, mask.view(batch_size, max_num_nodes)
+
+
def to_dense_adj(edge_index: Tensor, batch=None, edge_attr=None, max_num_nodes: Optional[int] = None) -> Tensor:
    """
    Convert an edge list into dense, per-graph adjacency matrices.

    :param edge_index: tensor whose rows 0 and 1 hold the source and target node of every edge
    :param batch: graph id of every node; if omitted all nodes are placed in graph 0
    :param edge_attr: optional per-edge values; edges get the value 1 when it is omitted
    :param max_num_nodes: side length of each adjacency matrix; defaults to the largest node count.
        When given, edges touching a node index at or beyond this size are dropped
    :return: tensor of size ``[batch_size, max_num_nodes, max_num_nodes, ...]`` in which the values of
        edges sharing the same endpoints are summed
    """
    if batch is None:
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
        batch = edge_index.new_zeros(num_nodes)

    batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1
    one = batch.new_ones(batch.size(0))
    num_nodes = scatter_sum(one, batch, dim=0, dim_size=batch_size)
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    # Offset of the first node of the graph each node belongs to (gathered once, used for both endpoints).
    node_offset = cum_nodes[batch]
    idx0 = batch[edge_index[0]]
    idx1 = edge_index[0] - node_offset[edge_index[0]]
    idx2 = edge_index[1] - node_offset[edge_index[1]]

    if max_num_nodes is None:
        max_num_nodes = num_nodes.max().item()

    elif ((idx1.numel() > 0 and idx1.max() >= max_num_nodes)
          or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)):
        # Discard edges that do not fit into the requested matrix size.
        mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
        idx0 = idx0[mask]
        idx1 = idx1[mask]
        idx2 = idx2[mask]
        edge_attr = None if edge_attr is None else edge_attr[mask]

    if edge_attr is None:
        edge_attr = torch.ones(idx0.numel(), device=edge_index.device)

    size = [batch_size, max_num_nodes, max_num_nodes]
    size += list(edge_attr.size())[1:]

    # scatter_sum allocates the flat result itself, so no separate zero tensor is created here.
    flattened_size = batch_size * max_num_nodes * max_num_nodes
    idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2
    adj = scatter_sum(edge_attr, idx, dim=0, dim_size=flattened_size)
    return adj.view(size)
+
+
+###############################################################
+# GENN-A* Modules #
+###############################################################
+
+
class GraphPair:
    """
    Container holding two ``Graphs`` objects, stored as ``g1`` and ``g2``.
    """

    def __init__(self, x1: torch.Tensor, x2: torch.Tensor, adj1: torch.Tensor,
                 adj2: torch.Tensor, n1=None, n2=None):
        """
        :param x1: node features handed to ``Graphs`` for the first graph(s)
        :param x2: node features handed to ``Graphs`` for the second graph(s)
        :param adj1: adjacency handed to ``Graphs`` for the first graph(s)
        :param adj2: adjacency handed to ``Graphs`` for the second graph(s)
        :param n1: optional node counts handed to ``Graphs`` for the first graph(s)
        :param n2: optional node counts handed to ``Graphs`` for the second graph(s)
        """
        self.g1 = Graphs(x1, adj1, n1)
        self.g2 = Graphs(x2, adj2, n2)

    def __repr__(self):
        return "{}('g1' = {}, 'g2' = {})".format(self.__class__.__name__, self.g1, self.g2)

    def to_dict(self):
        """
        :return: a new dict mapping ``'g1'`` and ``'g2'`` to the two stored graph objects
        """
        return {'g1': self.g1, 'g2': self.g2}
+
+
class Graphs:
    """
    A batch of graphs given by dense node features and adjacency matrices.

    On construction the adjacency is converted to a sparse edge list and a ``batch`` vector is built
    that assigns every node to the index of its graph.
    """

    def __init__(self, x: torch.Tensor, adj: torch.Tensor, nodes_num=None):
        """
        :param x: node features, either 2-dimensional (one graph) or 3-dimensional (batched)
        :param adj: adjacency with the same number of dimensions as ``x``
        :param nodes_num: optional node count(s) of the graph(s); anything accepted by
            ``torch.as_tensor`` (e.g. a tensor or a plain int). Defaults to ``x.shape[1]`` for every graph
        """
        assert len(x.shape) == len(adj.shape), 'x and adj must have the same number of dimensions'
        if len(adj.shape) == 2:
            # A single graph is promoted to a batch of size 1.
            adj = adj.unsqueeze(dim=0)
            x = x.unsqueeze(dim=0)
        assert x.shape[0] == adj.shape[0], 'x and adj must have the same batch size'
        assert x.shape[1] == adj.shape[1], 'x and adj must have the same number of nodes'
        self.x = x
        self.adj = adj
        self.num_graphs = adj.shape[0]
        if nodes_num is not None:
            # as_tensor leaves tensors untouched and lets plain ints work with graph_process.
            self.nodes_num = torch.as_tensor(nodes_num)
        else:
            self.nodes_num = torch.tensor([x.shape[1]] * x.shape[0])
        if nodes_num is not None and self.x.shape[0] == 1:
            # For a single graph, drop the padded rows/columns beyond its real node count.
            valid_nodes = int(self.nodes_num.reshape(-1)[0])
            if self.x.shape[1] != valid_nodes:
                self.x = self.x[:, :valid_nodes, :]
                self.adj = self.adj[:, :valid_nodes, :valid_nodes]
        self.edge_index = None
        self.edge_weight = None
        self.batch = None
        self.graph_process()

    def graph_process(self):
        """
        Fill ``edge_index``, ``edge_weight`` and ``batch`` from the stored adjacency and node counts.
        """
        edge_index, edge_weight, _ = pygmtools.utils.dense_to_sparse(self.adj)
        self.edge_index = torch.cat([edge_index[:, :, 0].unsqueeze(dim=1),
                                     edge_index[:, :, 1].unsqueeze(dim=1)], dim=1)
        self.edge_weight = edge_weight.view(-1)
        self.batch = self._build_batch()

    def _build_batch(self):
        """Return the graph index of every node: graph ``i`` contributes ``nodes_num[i]`` copies of ``i``."""
        # Flattening also covers a 0-dimensional node count (a single graph with index 0).
        counts = self.nodes_num.detach().reshape(-1).cpu().to(torch.long)
        graph_ids = torch.arange(counts.numel(), dtype=torch.long)
        return torch.repeat_interleave(graph_ids, counts)

    def __repr__(self):
        message = "x = {}, adj = {}".format(list(self.x.shape), list(self.adj.shape))
        message += " edge_index = {}, edge_weight = {}".format(list(self.edge_index.shape), list(self.edge_weight.shape))
        message += " nodes_num = {}, num_graphs = {}".format(self.nodes_num.shape, self.num_graphs)
        # The summary is only formatted here; the object itself is left unchanged.
        return f"{self.__class__.__name__}({message})"
+
+
class AttentionModule(torch.nn.Module):
    """
    SimGNN attention layer that pools node embeddings into one embedding per graph.
    """

    def __init__(self, args):
        """
        :param args: mapping of settings; ``args['filters_3']`` gives the embedding size.
        """
        super().__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Create the square attention weight matrix of size ``filters_3``.
        """
        width = self.args['filters_3']
        self.weight_matrix = torch.nn.Parameter(torch.Tensor(width, width))

    def init_parameters(self):
        """
        Fill the weight matrix with Xavier-uniform values.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)

    def forward(self, x, batch, size=None):
        """
        Pool node embeddings into graph-level embeddings.

        :param x: node embeddings, one row per node.
        :param batch: graph index of every node.
        :param size: number of graphs; defaults to the last entry of ``batch`` plus one.
        :return: one attention-weighted sum of node embeddings per graph.
        """
        if size is None:
            size = batch[-1].item() + 1
        graph_mean = scatter_mean(x, batch, dim=0, dim_size=size)
        context = torch.tanh(torch.mm(graph_mean, self.weight_matrix))

        # Every node is gated by its (scaled) similarity to the context of its own graph.
        gate = torch.sigmoid((x * context[batch] * 10).sum(dim=1))
        gated_nodes = gate.unsqueeze(-1) * x

        return scatter_sum(gated_nodes, batch, dim=0, dim_size=size)

    def get_coefs(self, x):
        """
        Attention coefficients of the nodes of one graph.

        :param x: node embeddings of a single graph, one row per node.
        :return: sigmoid of the product between ``x`` and the transformed mean embedding.
        """
        context = torch.tanh(torch.matmul(x.mean(dim=0), self.weight_matrix))
        return torch.sigmoid(torch.matmul(x, context))
+
+
class TensorNetworkModule(torch.nn.Module):
    """
    SimGNN tensor network layer that turns two graph embeddings into a similarity vector.
    """

    def __init__(self, args):
        """
        :param args: mapping of settings; ``filters_3`` is the embedding size and ``tensor_neurons``
            the length of the produced similarity vector.
        """
        super().__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Create the bilinear weight tensor, the weight of the concatenated block and the bias.
        """
        emb_dim = self.args['filters_3']
        out_dim = self.args['tensor_neurons']
        self.weight_matrix = torch.nn.Parameter(torch.Tensor(emb_dim, emb_dim, out_dim))
        self.weight_matrix_block = torch.nn.Parameter(torch.Tensor(out_dim, 2*emb_dim))
        self.bias = torch.nn.Parameter(torch.Tensor(out_dim, 1))

    def init_parameters(self):
        """
        Fill all three parameters with Xavier-uniform values.
        """
        for param in (self.weight_matrix, self.weight_matrix_block, self.bias):
            torch.nn.init.xavier_uniform_(param)

    def forward(self, embedding_1, embedding_2):
        """
        Compute the similarity vector of every embedding pair.

        :param embedding_1: first embeddings after attention, one row per pair.
        :param embedding_2: second embeddings after attention, one row per pair.
        :return scores: ReLU of bilinear term + linear term on the concatenation + bias, one row per pair.
        """
        emb_dim = self.args['filters_3']
        batch_size = len(embedding_1)

        # Bilinear term: embedding_1 is contracted with the weight tensor, then with embedding_2.
        flat_weight = self.weight_matrix.view(emb_dim, -1)
        left = torch.matmul(embedding_1, flat_weight)
        left = left.view(batch_size, emb_dim, -1).permute([0, 2, 1])
        right = embedding_2.view(batch_size, emb_dim, 1)
        scoring = torch.matmul(left, right).view(batch_size, -1)

        # Linear term on the concatenated pair.
        combined_representation = torch.cat((embedding_1, embedding_2), 1)
        block_scoring = torch.t(torch.mm(self.weight_matrix_block, torch.t(combined_representation)))

        return F.relu(scoring + block_scoring + self.bias.view(-1))
+
+
class GCNConv(nn.Module):
    """
    Graph convolution layer for a single graph with a dense adjacency matrix.

    The layer applies a linear map to the node features and propagates the result with the
    symmetrically normalised adjacency (self-loops added).
    """

    def __init__(self, in_features: int, out_features: int):
        """
        :param in_features: size of each input node feature
        :param out_features: size of each output node feature
        """
        super(GCNConv, self).__init__()
        self.num_inputs = in_features
        self.num_outputs = out_features
        self.weight = Parameter(torch.empty((in_features, out_features)))
        self.bias = Parameter(torch.empty(out_features))
        # torch.empty leaves the memory uninitialised; give the parameters defined starting values.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the weight with Xavier-uniform values and the bias with zeros."""
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, A: Tensor, x: Tensor, norm: bool=True) -> Tensor:
        r"""
        Forward computation of graph convolution network.

        :param A: :math:`(n\times n)` adjacency matrix. :math:`n`: number of nodes. It is not modified
        :param x: :math:`(n\times d)` input node embedding. :math:`d`: feature dimension
        :param norm: accepted for interface compatibility; it currently has no effect and the
            adjacency is always normalised
        :return: :math:`(n\times d^\prime)` new node embedding
        """
        x = torch.mm(x, self.weight) + self.bias

        # Work on a copy so that the caller's adjacency keeps its original diagonal.
        A_hat = A.clone()
        A_hat.fill_diagonal_(1)
        # D^{-1/2} A_hat D^{-1/2}, with D the diagonal matrix of the row sums of A_hat.
        inv_sqrt_deg = torch.pow(torch.sum(A_hat, dim=1), exponent=-0.5)
        A_norm = inv_sqrt_deg.unsqueeze(1) * A_hat * inv_sqrt_deg.unsqueeze(0)
        return torch.mm(A_norm, x)

    def __repr__(self):
        return f"{self.__class__.__name__}(in_features={self.num_inputs}, out_features={self.num_outputs})"
diff --git a/pygmtools/pytorch_backend.py b/pygmtools/pytorch_backend.py
index 2d3c9eca..b43a52f2 100644
--- a/pygmtools/pytorch_backend.py
+++ b/pygmtools/pytorch_backend.py
@@ -17,6 +17,10 @@
import os
import pygmtools.utils
+from .pytorch_astar_modules import GCNConv, AttentionModule, TensorNetworkModule, GraphPair, \
+ VERY_LARGE_INT, to_dense_adj, to_dense_batch, default_parameter, check_layer_parameter, node_metric
+from torch import Tensor
+from pygmtools.a_star import a_star
#############################################
# Linear Assignment Problem Solvers #
@@ -25,9 +29,9 @@
from pygmtools.numpy_backend import _hung_kernel
-def hungarian(s: Tensor, n1: Tensor=None, n2: Tensor=None,
- unmatch1: Tensor=None, unmatch2: Tensor=None,
- nproc: int=1) -> Tensor:
+def hungarian(s: Tensor, n1: Tensor = None, n2: Tensor = None,
+ unmatch1: Tensor = None, unmatch2: Tensor = None,
+ nproc: int = 1) -> Tensor:
"""
Pytorch implementation of Hungarian algorithm
"""
@@ -57,16 +61,17 @@ def hungarian(s: Tensor, n1: Tensor=None, n2: Tensor=None,
mapresult = pool.starmap_async(_hung_kernel, zip(perm_mat, n1, n2, unmatch1, unmatch2))
perm_mat = np.stack(mapresult.get())
else:
- perm_mat = np.stack([_hung_kernel(perm_mat[b], n1[b], n2[b], unmatch1[b], unmatch2[b]) for b in range(batch_num)])
+ perm_mat = np.stack(
+ [_hung_kernel(perm_mat[b], n1[b], n2[b], unmatch1[b], unmatch2[b]) for b in range(batch_num)])
perm_mat = torch.from_numpy(perm_mat).to(device)
return perm_mat
-def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None,
- unmatchrows: Tensor=None, unmatchcols: Tensor=None,
- dummy_row: bool=False, max_iter: int=10, tau: float=1., batched_operation: bool=False) -> Tensor:
+def sinkhorn(s: Tensor, nrows: Tensor = None, ncols: Tensor = None,
+ unmatchrows: Tensor = None, unmatchcols: Tensor = None,
+ dummy_row: bool = False, max_iter: int = 10, tau: float = 1., batched_operation: bool = False) -> Tensor:
"""
Pytorch implementation of Sinkhorn algorithm
"""
@@ -91,7 +96,7 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None,
s_t = s.transpose(1, 2)
s_t = torch.cat((
s_t[:, :s.shape[1], :],
- torch.full((batch_size, s.shape[1], s.shape[2]-s.shape[1]), -float('inf'), device=s.device)), dim=2)
+ torch.full((batch_size, s.shape[1], s.shape[2] - s.shape[1]), -float('inf'), device=s.device)), dim=2)
s = torch.where(transposed_batch.view(batch_size, 1, 1), s_t, s)
new_nrows = torch.where(transposed_batch, ncols, nrows)
@@ -103,8 +108,9 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None,
unmatchrows_pad = torch.cat((
unmatchrows,
torch.full((batch_size, unmatchcols.shape[1] - unmatchrows.shape[1]), -float('inf'), device=s.device)),
- dim=1)
- new_unmatchrows = torch.where(transposed_batch.view(batch_size, 1), unmatchcols, unmatchrows_pad)[:, :unmatchrows.shape[1]]
+ dim=1)
+ new_unmatchrows = torch.where(transposed_batch.view(batch_size, 1), unmatchcols, unmatchrows_pad)[:,
+ :unmatchrows.shape[1]]
new_unmatchcols = torch.where(transposed_batch.view(batch_size, 1), unmatchrows_pad, unmatchcols)
unmatchrows = new_unmatchrows
unmatchcols = new_unmatchcols
@@ -121,15 +127,19 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None,
dummy_shape[1] = log_s.shape[2] - log_s.shape[1]
ori_nrows = nrows
nrows = ncols.clone()
- log_s = torch.cat((log_s, torch.full(dummy_shape, -float('inf'), device=log_s.device, dtype=log_s.dtype)), dim=1)
+ log_s = torch.cat((log_s, torch.full(dummy_shape, -float('inf'), device=log_s.device, dtype=log_s.dtype)),
+ dim=1)
if unmatchrows is not None:
- unmatchrows = torch.cat((unmatchrows, torch.full((dummy_shape[0], dummy_shape[1]), -float('inf'), device=log_s.device, dtype=log_s.dtype)), dim=1)
+ unmatchrows = torch.cat((unmatchrows,
+ torch.full((dummy_shape[0], dummy_shape[1]), -float('inf'), device=log_s.device,
+ dtype=log_s.dtype)), dim=1)
for b in range(batch_size):
log_s[b, ori_nrows[b]:nrows[b], :ncols[b]] = -100
# assign the unmatch weights
if unmatchrows is not None and unmatchcols is not None:
- new_log_s = torch.full((log_s.shape[0], log_s.shape[1]+1, log_s.shape[2]+1), -float('inf'), device=log_s.device, dtype=log_s.dtype)
+ new_log_s = torch.full((log_s.shape[0], log_s.shape[1] + 1, log_s.shape[2] + 1), -float('inf'),
+ device=log_s.device, dtype=log_s.dtype)
new_log_s[:, :-1, :-1] = log_s
log_s = new_log_s
for b in range(batch_size):
@@ -161,7 +171,8 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None,
ret_log_s = log_s
else:
- ret_log_s = torch.full((batch_size, log_s.shape[1], log_s.shape[2]), -float('inf'), device=log_s.device, dtype=log_s.dtype)
+ ret_log_s = torch.full((batch_size, log_s.shape[1], log_s.shape[2]), -float('inf'), device=log_s.device,
+ dtype=log_s.dtype)
for b in range(batch_size):
row_slice = slice(0, nrows[b])
@@ -198,7 +209,8 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None,
s_t = ret_log_s.transpose(1, 2)
s_t = torch.cat((
s_t[:, :ret_log_s.shape[1], :],
- torch.full((batch_size, ret_log_s.shape[1], ret_log_s.shape[2]-ret_log_s.shape[1]), -float('inf'), device=log_s.device)), dim=2)
+ torch.full((batch_size, ret_log_s.shape[1], ret_log_s.shape[2] - ret_log_s.shape[1]), -float('inf'),
+ device=log_s.device)), dim=2)
ret_log_s = torch.where(transposed_batch.view(batch_size, 1, 1), s_t, ret_log_s)
if transposed:
@@ -410,7 +422,7 @@ def _comp_aff_score(x, k):
X1 = X.reshape(m, 1, m, n, n).repeat(1, m, 1, 1, 1).reshape(-1, n, n) # X1[i,j,k] = X[i,k]
X2 = X.reshape(1, m, m, n, n).repeat(m, 1, 1, 1, 1).transpose(1, 2).reshape(-1, n, n) # X2[i,j,k] = X[k,j]
- X_combo = torch.bmm(X1, X2).reshape(m, m, m, n, n) # X_combo[i,j,k] = X[i, k] * X[k, j]
+ X_combo = torch.bmm(X1, X2).reshape(m, m, m, n, n) # X_combo[i,j,k] = X[i, k] * X[k, j]
aff_ori = (_comp_aff_score(X.reshape(-1, n, n), K.reshape(-1, n * n, n * n)) / norm).reshape(m, m)
pair_con = _get_batch_pc_opt(X)
@@ -433,7 +445,8 @@ def _comp_aff_score(x, k):
assert torch.all(score_combo + 1e-4 >= score_ori), torch.min(score_combo - score_ori)
X_upt = X_combo[mask1, mask2, idx, :, :]
- X = X_upt * X_mask + X_upt.transpose(0, 1).transpose(2, 3) * X_mask.transpose(0, 1) + X * (1 - X_mask - X_mask.transpose(0, 1))
+ X = X_upt * X_mask + X_upt.transpose(0, 1).transpose(2, 3) * X_mask.transpose(0, 1) + X * (
+ 1 - X_mask - X_mask.transpose(0, 1))
assert torch.all(X.transpose(0, 1).transpose(2, 3) == X)
return X
@@ -597,7 +610,7 @@ def gamgm(
sk_iter, max_iter, quad_weight,
converge_thresh, outlier_thresh, bb_smooth,
verbose,
- cluster_M=None, projector='sinkhorn', hung_iter=True # these arguments are reserved for clustering
+ cluster_M=None, projector='sinkhorn', hung_iter=True # these arguments are reserved for clustering
):
"""
Pytorch implementation of Graduated Assignment for Multi-Graph Matching (with compatibility for 2GM and clustering)
@@ -661,6 +674,7 @@ class GAMGMTorchFunc(torch.autograd.Function):
"""
Torch wrapper to support forward and backward pass (by black-box differentiation)
"""
+
@staticmethod
def forward(ctx, bb_smooth, supA, supW, ns, n_indices, n_univ, num_graphs, U0, *args):
# save parameters
@@ -688,7 +702,8 @@ def backward(ctx, dU):
end_x = n_indices[i]
start_y = n_indices[j] - ns[j]
end_y = n_indices[j]
- supW[start_x:end_x, start_y:end_y] += bb_smooth * torch.mm(dU[start_x:end_x], dU[start_y:end_y].transpose(0, 1))
+ supW[start_x:end_x, start_y:end_y] += bb_smooth * torch.mm(dU[start_x:end_x],
+ dU[start_y:end_y].transpose(0, 1))
U_prime = gamgm_real(supA, supW, ns, n_indices, n_univ, num_graphs, U0, *args)
@@ -712,8 +727,8 @@ def gamgm_real(
sk_iter, max_iter, quad_weight,
converge_thresh, outlier_thresh,
verbose,
- cluster_M, projector, hung_iter # these arguments are reserved for clustering
- ):
+ cluster_M, projector, hung_iter # these arguments are reserved for clustering
+):
"""
The real forward function of GAMGM
"""
@@ -736,7 +751,7 @@ def gamgm_real(
else:
print_str = 'hungarian'
print(print_str + f' #iter={i}/{max_iter} '
- f'quad score: {(quad * U).sum():.3e}, unary score: {(unary * U).sum():.3e}')
+ f'quad score: {(quad * U).sum():.3e}, unary score: {(unary * U).sum():.3e}')
V = (quad + unary) / num_graphs
U_list = []
@@ -813,7 +828,7 @@ def gamgm_real(
if verbose: print('-' * 20)
- if i == max_iter - 1: # not converged
+ if i == max_iter - 1: # not converged
if hung_iter:
pass
else:
@@ -837,6 +852,329 @@ def gamgm_real(
return U
+astar_pretrain_path = {
+ 'AIDS700nef': ('https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/pytorch_backend/best_genn_AIDS700nef_gcn_astar.pt',
+ 'b2516aea4c8d730704a48653a5ca94ba'),
+ 'LINUX': ('https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/pytorch_backend/best_genn_LINUX_gcn_astar.pt',
+ 'fd3b2a8dfa3edb20607da2e2b96d2e96'),
+}
+
+
+class GENN(torch.nn.Module):
+ def __init__(self, args):
+ """
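+ :param args: Arguments dictionary; ``args['use_net']`` decides whether the network layers are built.
+ The number of node labels is read from ``args['channel']`` rather than passed as a parameter.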
+ """
+ super(GENN, self).__init__()
+ self.args = args
+ if self.args['use_net']:
+ self.number_labels = self.args['channel']
+ self.setup_layers()
+
+ self.reset_cache()
+
+ def reset_cache(self):
+ self.gnn_1_cache = dict()
+ self.gnn_2_cache = dict()
+ self.heuristic_cache = dict()
+
+ def setup_layers(self):
+ """
+ Creating the layers.
+ """
+ self.feature_count = self.args['tensor_neurons']
+ self.convolution_1 = GCNConv(self.number_labels, self.args['filters_1'])
+ self.convolution_2 = GCNConv(self.args['filters_1'], self.args['filters_2'])
+ self.convolution_3 = GCNConv(self.args['filters_2'], self.args['filters_3'])
+ self.attention = AttentionModule(self.args)
+ self.tensor_network = TensorNetworkModule(self.args)
+ self.scoring_layer = torch.nn.Sequential(
+ torch.nn.Linear(self.feature_count, 16),
+ torch.nn.ReLU(),
+ torch.nn.Linear(16, 1),
+ torch.nn.Sigmoid()
+ )
+
+ def convolutional_pass(self, edge_index, x, edge_weight=None):
+ """
+ Making convolutional pass.
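+ :param edge_index: Adjacency matrix, passed to every GCNConv layer as its ``A`` argument.
+ :param x: Feature matrix.
+ :param edge_weight: Passed positionally to every GCNConv layer, where it lands in the ``norm`` parameter.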
+ :return features: Abstract feature matrix.
+ """
+
+ features = self.convolution_1(edge_index, x, edge_weight)
+ features = F.relu(features)
+ features = F.dropout(features, p=self.args['dropout'], training=self.training)
+ features = self.convolution_2(edge_index, features, edge_weight)
+ features = F.relu(features)
+ features = F.dropout(features, p=self.args['dropout'], training=self.training)
+ features = self.convolution_3(edge_index, features, edge_weight)
+ return features
+
+ def forward(self, data: GraphPair):
+ """
+ Run the A* search separately on every graph pair in the batch.
+ :param data: Batched GraphPair object.
+ :return x_pred: Stacked per-pair A* results, with the last row and column dropped.
+ """
+ num = data.g1.num_graphs
+ max_nodes_num_1 = torch.max(data.g1.nodes_num) + 1
+ max_nodes_num_2 = torch.max(data.g2.nodes_num) + 1
+ x_pred = torch.zeros(num, max_nodes_num_1, max_nodes_num_2)
+ for i in range(num):
+ cur_data = GraphPair(data.g1.x[i], data.g2.x[i], data.g1.adj[i], data.g2.adj[i],
+ data.g1.nodes_num[i], data.g2.nodes_num[i])
+ num_nodes_1 = data.g1.nodes_num[i] + 1
+ num_nodes_2 = data.g2.nodes_num[i] + 1
+ x_pred[i][:num_nodes_1, :num_nodes_2] = self._a_star(cur_data)
+ return x_pred[:, :-1, :-1]
+
+ def _a_star(self, data: GraphPair):
+
+ if self.args['cuda']:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ device = "cpu"
+ edge_index_1 = data.g1.edge_index.squeeze()
+ edge_index_2 = data.g2.edge_index.squeeze()
+ edge_attr_1 = data.g1.edge_weight
+ edge_attr_2 = data.g2.edge_weight
+ node_1 = data.g1.x.squeeze()
+ node_2 = data.g2.x.squeeze()
+ batch_1 = data.g1.batch
+ batch_2 = data.g2.batch
+ batch_num = data.g1.num_graphs
+
+ ns_1 = torch.bincount(data.g1.batch)
+ ns_2 = torch.bincount(data.g2.batch)
+
+ adj_1 = to_dense_adj(edge_index_1, batch=batch_1, edge_attr=edge_attr_1)
+
+ dummy_adj_1 = torch.zeros(adj_1.shape[0], adj_1.shape[1] + 1, adj_1.shape[2] + 1, device=device)
+ dummy_adj_1[:, :-1, :-1] = adj_1
+ adj_2 = to_dense_adj(edge_index_2, batch=batch_2, edge_attr=edge_attr_2)
+ dummy_adj_2 = torch.zeros(adj_2.shape[0], adj_2.shape[1] + 1, adj_2.shape[2] + 1, device=device)
+ dummy_adj_2[:, :-1, :-1] = adj_2
+
+ node_1, _ = to_dense_batch(node_1, batch=batch_1)
+ node_2, _ = to_dense_batch(node_2, batch=batch_2)
+
+ dummy_node_1 = torch.zeros(adj_1.shape[0], node_1.shape[1] + 1, node_1.shape[-1], device=device)
+ dummy_node_1[:, :-1, :] = node_1
+ dummy_node_2 = torch.zeros(adj_2.shape[0], node_2.shape[1] + 1, node_2.shape[-1], device=device)
+ dummy_node_2[:, :-1, :] = node_2
+ k_diag = node_metric(dummy_node_1, dummy_node_2)
+
+ mask_1 = torch.zeros_like(dummy_adj_1)
+ mask_2 = torch.zeros_like(dummy_adj_2)
+ for b in range(batch_num):
+ mask_1[b, :ns_1[b] + 1, :ns_1[b] + 1] = 1
+ mask_1[b, :ns_1[b], :ns_1[b]] -= torch.eye(ns_1[b], device=mask_1.device)
+ mask_2[b, :ns_2[b] + 1, :ns_2[b] + 1] = 1
+ mask_2[b, :ns_2[b], :ns_2[b]] -= torch.eye(ns_2[b], device=mask_2.device)
+
+ a1 = dummy_adj_1.reshape(batch_num, -1, 1)
+ a2 = dummy_adj_2.reshape(batch_num, 1, -1)
+ m1 = mask_1.reshape(batch_num, -1, 1)
+ m2 = mask_2.reshape(batch_num, 1, -1)
+ k = torch.abs(a1 - a2) * torch.bmm(m1, m2)
+ k[torch.logical_not(torch.bmm(m1, m2).to(dtype=torch.bool))] = VERY_LARGE_INT
+ k = k.reshape(batch_num, dummy_adj_1.shape[1], dummy_adj_1.shape[2], dummy_adj_2.shape[1], dummy_adj_2.shape[2])
+ k = k.permute([0, 1, 3, 2, 4])
+ k = k.reshape(batch_num, dummy_adj_1.shape[1] * dummy_adj_2.shape[1],
+ dummy_adj_1.shape[2] * dummy_adj_2.shape[2])
+ k = k / 2
+
+ for b in range(batch_num):
+ k_diag_view = torch.diagonal(k[b])
+ k_diag_view[:] = k_diag[b].reshape(-1)
+
+ self.reset_cache()
+
+ x_pred, _ = a_star(
+ data, k, ns_1.cpu().numpy(), ns_2.cpu().numpy(),
+ self.net_prediction_cache,
+ self.heuristic_prediction_hun,
+ net_pred=self.args['use_net'],
+ beam_width=self.args['astar_beam_width'],
+ trust_fact=self.args['astar_trust_fact'],
+ no_pred_size=self.args['astar_no_pred'],
+ )
+
+ return x_pred
+
+ def net_prediction_cache(self, data: GraphPair, partial_pmat=None, return_ged_norm=False):
+ """
+ Forward pass with graphs.
+ :param data: Data class.
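+ :param partial_pmat: Partial matching matrix; nodes that are already matched are masked out.
+ :param return_ged_norm: If True, return the network score itself instead of the GED estimate.
+ :return score: The network score, or ``-log(score) * (unmatched nodes in both graphs) / 2`` as the GED estimate.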
+ """
+ features_1 = data.g1.x.squeeze()
+ features_2 = data.g2.x.squeeze()
+ batch_1 = data.g1.batch
+ batch_2 = data.g2.batch
+ adj1 = data.g1.adj.squeeze()
+ adj2 = data.g2.adj.squeeze()
+
+ if 'gnn_feat' not in self.gnn_1_cache:
+ abstract_features_1 = self.convolutional_pass(adj1, features_1)
+ self.gnn_1_cache['gnn_feat'] = abstract_features_1
+ else:
+ abstract_features_1 = self.gnn_1_cache['gnn_feat']
+ if 'gnn_feat' not in self.gnn_2_cache:
+ abstract_features_2 = self.convolutional_pass(adj2, features_2)
+ self.gnn_2_cache['gnn_feat'] = abstract_features_2
+ else:
+ abstract_features_2 = self.gnn_2_cache['gnn_feat']
+
+ graph_1_mask = torch.ones_like(batch_1)
+ graph_2_mask = torch.ones_like(batch_2)
+ graph_1_matched = partial_pmat.sum(dim=-1).to(dtype=torch.bool)[:graph_1_mask.shape[0]]
+ graph_2_matched = partial_pmat.sum(dim=-2).to(dtype=torch.bool)[:graph_2_mask.shape[0]]
+ graph_1_mask = torch.logical_not(graph_1_matched)
+ graph_2_mask = torch.logical_not(graph_2_matched)
+ abstract_features_1 = abstract_features_1[graph_1_mask]
+ abstract_features_2 = abstract_features_2[graph_2_mask]
+ batch_1 = batch_1[graph_1_mask]
+ batch_2 = batch_2[graph_2_mask]
+ pooled_features_1 = self.attention(abstract_features_1, batch_1)
+ pooled_features_2 = self.attention(abstract_features_2, batch_2)
+ scores = self.tensor_network(pooled_features_1, pooled_features_2)
+ score = self.scoring_layer(scores).view(-1)
+
+ if return_ged_norm:
+ return score
+ else:
+ ged = - torch.log(score) * (batch_1.shape[0] + batch_2.shape[0]) / 2
+ return ged
+
+ def heuristic_prediction_hun(self, k: torch.Tensor, n1, n2, partial_pmat):
+ k_prime = k.reshape(-1, n1 + 1, n2 + 1)
+ node_costs = torch.empty(k_prime.shape[0])
+ for i in range(k_prime.shape[0]):
+ _, node_costs[i] = hungarian_ged(k_prime[i], n1, n2)
+ node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1)
+ self.heuristic_cache['node_cost'] = node_cost_mat
+
+ graph_1_mask = ~partial_pmat.sum(dim=-1).to(dtype=torch.bool)
+ graph_2_mask = ~partial_pmat.sum(dim=-2).to(dtype=torch.bool)
+ graph_1_mask[-1] = 1
+ graph_2_mask[-1] = 1
+ node_cost_mat = node_cost_mat[graph_1_mask, :]
+ node_cost_mat = node_cost_mat[:, graph_2_mask]
+
+ _, ged = hungarian_ged(node_cost_mat, torch.sum(graph_1_mask[:-1]), torch.sum(graph_2_mask[:-1]))
+
+ return ged
+
+
+def hungarian_ged(node_cost_mat: torch.Tensor, n1, n2):
+ assert node_cost_mat.shape[-2] == n1 + 1
+ assert node_cost_mat.shape[-1] == n2 + 1
+ device = node_cost_mat.device
+ upper_left = node_cost_mat[:n1, :n2]
+ upper_right = torch.full((n1, n1), float('inf'), device=device)
+ torch.diagonal(upper_right)[:] = node_cost_mat[:-1, -1]
+ lower_left = torch.full((n2, n2), float('inf'), device=device)
+ torch.diagonal(lower_left)[:] = node_cost_mat[-1, :-1]
+ lower_right = torch.zeros((n2, n1), device=device)
+ large_cost_mat = torch.cat((torch.cat((upper_left, upper_right), dim=1),
+ torch.cat((lower_left, lower_right), dim=1)), dim=0)
+
+ large_pred_x = hungarian(-large_cost_mat.unsqueeze(dim=0)).squeeze()
+ pred_x = torch.zeros_like(node_cost_mat)
+ pred_x[:n1, :n2] = large_pred_x[:n1, :n2]
+ pred_x[:-1, -1] = torch.sum(large_pred_x[:n1, n2:], dim=1)
+ pred_x[-1, :-1] = torch.sum(large_pred_x[n1:, :n2], dim=0)
+
+ ged_lower_bound = torch.sum(pred_x * node_cost_mat)
+ return pred_x, ged_lower_bound
+
+
+def astar(feat1, feat2, A1, A2, n1, n2, channel, dropout, beam_width, trust_fact, no_pred_size):
+ """
+ Pytorch implementation of ASTAR
+ """
+ return astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, dropout=dropout, beam_width=beam_width,
+ filters_1=64, filters_2=32, filters_3=16, tensor_neurons=16, trust_fact=trust_fact,
+ no_pred_size=no_pred_size, pretrain=False, network=None, use_net=False)
+
+
+def astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3,
+ tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain, use_net):
+ """
+ The true implementation of astar and genn_astar functions
+ """
+ if feat1 is None:
+ forward_pass = False
+ device = torch.device('cpu')
+ else:
+ assert feat1.shape[-1] == feat2.shape[-1], 'The feature dimensions of feat1 and feat2 must be consistent'
+ forward_pass = True
+ device = feat1.device
+
+ if network is None:
+ args = default_parameter()
+ if forward_pass:
+ if channel is None:
+ args['channel'] = feat1.shape[-1]
+ else:
+ assert feat1.shape[-1] == channel, 'the channel {} must match the feature dimension of feat1\n'.format(
+ channel)
+ args['channel'] = channel
+ else:
+ if channel is None:
+ args['channel'] = 8 if pretrain == "LINUX" else 36
+ else:
+ args['channel'] = channel
+
+ args['filters_1'] = filters_1
+ args['filters_2'] = filters_2
+ args['filters_3'] = filters_3
+ args['tensor_neurons'] = tensor_neurons
+ args['dropout'] = dropout
+ args['astar_beam_width'] = beam_width
+ args['astar_trust_fact'] = trust_fact
+ args['astar_no_pred'] = no_pred_size
+ args['pretrain'] = pretrain
+ args['use_net'] = use_net
+
+ network = GENN(args)
+
+ network = network.to(device)
+ if pretrain and args['use_net']:
+ if pretrain in astar_pretrain_path:
+ url, md5 = astar_pretrain_path[pretrain]
+ filename = pygmtools.utils.download(f'best_genn_{pretrain}_gcn_astar.pt', url, md5)
+ if check_layer_parameter(args):
+ _load_model(network, filename, device)
+ else:
+ _load_part_model(network, filename, device)
+ message = 'Warning: Pretrain {} does not support the parameters you entered, '.format(pretrain)
+ if args['pretrain'] == 'AIDS700nef':
+ message += "Supported parameters: ( channel:36, filters:(64,32,16) ), "
+ elif args['pretrain'] == 'LINUX':
+ message += "Supported parameters: ( channel:8, filters:(64,32,16) ), "
+ message += 'Input parameters: ( channel:{}, filters:({},{},{}) )'.format(args['channel'], \
+ args['filters_1'], args['filters_2'], args['filters_3'])
+ print(message)
+ else:
+ raise ValueError(f'Unknown pretrain tag. Available tags: {astar_pretrain_path.keys()}')
+
+ if forward_pass:
+ assert A1.shape[0] == A2.shape[0]
+ data = GraphPair(feat1, feat2, A1, A2, n1, n2)
+ result = network(data)
+ else:
+ result = None
+ return result, network
+
+
############################################
# Neural Network Solvers #
############################################
@@ -848,6 +1186,7 @@ class PCA_GM_Net(torch.nn.Module):
"""
Pytorch implementation of PCA-GM and IPCA-GM network
"""
+
def __init__(self, in_channel, hidden_channel, out_channel, num_layers, cross_iter_num=-1):
super(PCA_GM_Net, self).__init__()
self.gnn_layer = num_layers
@@ -863,8 +1202,7 @@ def __init__(self, in_channel, hidden_channel, out_channel, num_layers, cross_it
if i == self.gnn_layer - 2: # only the second last layer will have cross-graph module
self.add_module('cross_graph_{}'.format(i), torch.nn.Linear(hidden_channel * 2, hidden_channel))
if cross_iter_num <= 0:
- self.add_module('affinity_{}'.format(i), WeightedInnerProdAffinity(hidden_channel))
-
+ self.add_module('affinity_{}'.format(i), WeightedInnerProdAffinity(hidden_channel))
def forward(self, feat1, feat2, A1, A2, n1, n2, cross_iter_num, sk_max_iter, sk_tau):
_sinkhorn_func = functools.partial(sinkhorn,
@@ -972,8 +1310,8 @@ def pca_gm(feat1, feat2, A1, A2, n1, n2,
def ipca_gm(feat1, feat2, A1, A2, n1, n2,
- in_channel, hidden_channel, out_channel, num_layers, cross_iter, sk_max_iter, sk_tau,
- network, pretrain):
+ in_channel, hidden_channel, out_channel, num_layers, cross_iter, sk_max_iter, sk_tau,
+ network, pretrain):
"""
Pytorch implementation of IPCA-GM
"""
@@ -1009,6 +1347,7 @@ class CIE_Net(torch.nn.Module):
"""
Pytorch implementation of CIE graph matching network
"""
+
def __init__(self, in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers):
super(CIE_Net, self).__init__()
self.gnn_layer = num_layers
@@ -1099,6 +1438,7 @@ class NGM_Net(torch.nn.Module):
"""
Pytorch implementation of NGM network
"""
+
def __init__(self, gnn_channels, sk_emb):
super(NGM_Net, self).__init__()
self.gnn_layer = len(gnn_channels)
@@ -1173,6 +1513,15 @@ def ngm(K, n1, n2, n1max, n2max, x0, gnn_channels, sk_emb, sk_max_iter, sk_tau,
return result, network
+def genn_astar(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3,
+ tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain):
+ """
+ Pytorch implementation of GENN-ASTAR
+ """
+ return astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3,
+ tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain, use_net=True)
+
+
#############################################
# Utils Functions #
#############################################
@@ -1222,7 +1571,8 @@ def build_batch(input, return_ori_dim=False):
padded_ts.append(torch.nn.functional.pad(t, pad_pattern, 'constant', 0))
if return_ori_dim:
- return torch.stack(padded_ts, dim=0), tuple([torch.tensor(_, dtype=torch.int64, device=device) for _ in ori_shape])
+ return torch.stack(padded_ts, dim=0), tuple(
+ [torch.tensor(_, dtype=torch.int64, device=device) for _ in ori_shape])
else:
return torch.stack(padded_ts, dim=0)
@@ -1361,8 +1711,10 @@ def _aff_mat_from_node_edge_aff(node_aff: Tensor, edge_aff: Tensor, connectivity
if edge_aff is not None:
conn1 = connectivity1[b][:ne1[b]]
conn2 = connectivity2[b][:ne2[b]]
- edge_indices = torch.cat([conn1.repeat_interleave(ne2[b], dim=0), conn2.repeat(ne1[b], 1)], dim=1) # indices: start_g1, end_g1, start_g2, end_g2
- edge_indices = (edge_indices[:, 2], edge_indices[:, 0], edge_indices[:, 3], edge_indices[:, 1]) # indices: start_g2, start_g1, end_g2, end_g1
+ edge_indices = torch.cat([conn1.repeat_interleave(ne2[b], dim=0), conn2.repeat(ne1[b], 1)],
+ dim=1) # indices: start_g1, end_g1, start_g2, end_g2
+ edge_indices = (edge_indices[:, 2], edge_indices[:, 0], edge_indices[:, 3],
+ edge_indices[:, 1]) # indices: start_g2, start_g1, end_g2, end_g1
k[edge_indices] = edge_aff[b, :ne1[b], :ne2[b]].reshape(-1)
k = k.reshape(n2max * n1max, n2max * n1max)
# node-wise affinity
@@ -1451,3 +1803,25 @@ def _load_model(model, path, device, strict=True):
if len(missing_keys) > 0:
print('Warning: Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
+
+
+def _load_part_model(model, path, device):
+ """
+ Load a PyTorch model from a given path, skipping any parameter that cannot be copied (e.g. size mismatch).
+ """
+ if isinstance(model, torch.nn.DataParallel):
+ module = model.module
+ else:
+ module = model
+ model_state_dict = module.state_dict()
+ load_state_dict = torch.load(path, map_location=device)
+ skip_keys= list()
+ for name, param in load_state_dict.items():
+ if name in model_state_dict:
+ try:
+ model_state_dict[name].copy_(param)
+ except:
+ skip_keys.append(name)
+ if len(skip_keys) > 0:
+ print('Warning: Skip key(s) in state_dict: {}. '.format(
+ ', '.join('"{}"'.format(k) for k in skip_keys)))
diff --git a/pygmtools/utils.py b/pygmtools/utils.py
index f2c75526..e496cca5 100644
--- a/pygmtools/utils.py
+++ b/pygmtools/utils.py
@@ -26,6 +26,8 @@
import wget
import numpy as np
import pygmtools
+import networkx as nx
+import urllib.request
NOT_IMPLEMENTED_MSG = \
'The backend function for {} is not implemented. ' \
@@ -1200,6 +1202,7 @@ def _mm(input1, input2, backend=None):
)
return fn(*args)
+
def download(filename, url, md5=None, retries=10, to_cache=True):
r"""
Check if content exits. If not, download the content to ``/pygmtools/``. ````
@@ -1207,7 +1210,7 @@ def download(filename, url, md5=None, retries=10, to_cache=True):
:param filename: the destination file name
:param url: the url
:param md5: (optional) the md5sum to verify the content. It should match the result of ``md5sum file`` on Linux.
- :param retries: (default: 5) max number of retries
+ :param retries: (default: 10) max number of retries
:return: the full path to the file: ``/pygmtools/``
"""
if retries <= 0:
@@ -1220,7 +1223,7 @@ def download(filename, url, md5=None, retries=10, to_cache=True):
filename = os.path.join(dirs, filename)
if not os.path.exists(filename):
print(f'\nDownloading to {filename}...')
- if retries % 2 == 1:
+ if retries % 3 == 1:
try:
down_res = requests.get(url, stream=True)
file_size = int(down_res.headers.get('Content-Length', 0))
@@ -1230,11 +1233,17 @@ def download(filename, url, md5=None, retries=10, to_cache=True):
except requests.exceptions.ConnectionError as err:
print('Warning: Network error. Retrying...\n', err)
return download(filename, url, md5, retries - 1)
- else:
+ elif retries % 3 == 2:
try:
wget.download(url,out=filename)
except:
return download(filename, url, md5, retries - 1)
+ else:
+ try:
+ urllib.request.urlretrieve(url, filename)
+ except:
+ return download(filename, url, md5, retries - 1)
+
if md5 is not None:
md5_returned = _get_md5(filename)
if md5 != md5_returned:
@@ -1244,6 +1253,7 @@ def download(filename, url, md5=None, retries=10, to_cache=True):
return download(filename, url, md5, retries - 1)
return filename
+
def _get_md5(filename):
hash_md5 = hashlib.md5()
chunk = 8192
@@ -1255,3 +1265,254 @@ def _get_md5(filename):
hash_md5.update(buffer)
md5_returned = hash_md5.hexdigest()
return md5_returned
+
+
+###################################################
+# Support NetworkX and GraphML formats #
+###################################################
+
+
+def build_aff_mat_from_networkx(G1:nx.Graph, G2:nx.Graph, node_aff_fn=None, edge_aff_fn=None, backend=None):
+ r"""
+ Build the affinity matrix of two networkx graphs
+
+ :param G1: networkx object, whose type must be networkx.Graph
+ :param G2: networkx object, whose type must be networkx.Graph
+ :param node_aff_fn: (default: inner_prod_aff_fn) the node affinity function with the characteristic
+ ``node_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two node feature tensors and
+ outputs the node-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an
+ example.
+ :param edge_aff_fn: (default: inner_prod_aff_fn) the edge affinity function with the characteristic
+ ``edge_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two edge feature tensors and
+ outputs the edge-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an
+ example.
+ :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
+ :return: the affinity matrix corresponding to the networkx object G1 and G2
+
+ .. dropdown:: Example
+
+ ::
+
+ >>> import networkx as nx
+ >>> import pygmtools as pygm
+ >>> pygm.BACKEND = 'numpy'
+
+ # Generate networkx graphs
+ >>> G1 = nx.DiGraph()
+ >>> G1.add_weighted_edges_from([(1, 2, 0.5), (2, 3, 0.8), (3, 4, 0.7)])
+ >>> G2 = nx.DiGraph()
+ >>> G2.add_weighted_edges_from([(1, 2, 0.3), (2, 3, 0.6), (3, 4, 0.9), (4, 5, 0.4)])
+
+ # Obtain Affinity Matrix
+ >>> K = pygm.utils.build_aff_mat_from_networkx(G1, G2)
+ >>> K.shape
+ (20, 20)
+
+ # The affinity matrix K can be further processed by GM solvers
+ """
+ if backend is None:
+ backend = pygmtools.BACKEND
+ A1 = from_numpy(np.asarray(from_networkx(G1)))
+ A2 = from_numpy(np.asarray(from_networkx(G2)))
+ conn1, edge1 = dense_to_sparse(A1, backend=backend)
+ conn2, edge2 = dense_to_sparse(A2, backend=backend)
+ K = build_aff_mat(None, edge1, conn1, None, edge2, conn2, node_aff_fn=node_aff_fn, edge_aff_fn=edge_aff_fn, backend=backend)
+ return K
+
+
+def build_aff_mat_from_graphml(G1_path, G2_path, node_aff_fn=None, edge_aff_fn=None, backend=None):
+ r"""
+ Build the affinity matrix of two graphs stored in GraphML files
+
+ :param G1_path: The file path of the graphml object
+ :param G2_path: The file path of the graphml object
+ :param node_aff_fn: (default: inner_prod_aff_fn) the node affinity function with the characteristic
+ ``node_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two node feature tensors and
+ outputs the node-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an
+ example.
+ :param edge_aff_fn: (default: inner_prod_aff_fn) the edge affinity function with the characteristic
+ ``edge_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two edge feature tensors and
+ outputs the edge-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an
+ example.
+ :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
+ :return: the affinity matrix corresponding to the graphml object G1 and G2
+
+
+ .. dropdown:: Example
+
+ ::
+
+ >>> import pygmtools as pygm
+ >>> pygm.BACKEND = 'numpy'
+
+ # example file (.graphml) path
+ >>> G1_path = 'examples/data/graph1.graphml'
+ >>> G2_path = 'examples/data/graph2.graphml'
+
+ # Obtain Affinity Matrix
+ >>> K = pygm.utils.build_aff_mat_from_graphml(G1_path, G2_path)
+ >>> K.shape
+ (121, 121)
+
+ # The affinity matrix K can be further processed by GM solvers
+ """
+ if backend is None:
+ backend = pygmtools.BACKEND
+ A1 = from_numpy(np.asarray(from_graphml(G1_path)))
+ A2 = from_numpy(np.asarray(from_graphml(G2_path)))
+ conn1, edge1 = dense_to_sparse(A1, backend=backend)
+ conn2, edge2 = dense_to_sparse(A2, backend=backend)
+ K = build_aff_mat(None, edge1, conn1, None, edge2, conn2, node_aff_fn=node_aff_fn, edge_aff_fn=edge_aff_fn, backend=backend)
+ return K
+
+
+def from_networkx(G:nx.Graph):
+ r"""
+ Convert networkx object to Adjacency matrix
+
+ :param G: networkx object, whose type must be networkx.Graph
+ :return: the adjacency matrix corresponding to the networkx object
+
+ .. dropdown:: Example
+
+ ::
+
+ >>> import networkx as nx
+ >>> import pygmtools as pygm
+ >>> pygm.BACKEND = 'numpy'
+
+ # Generate networkx graphs
+ >>> G1 = nx.DiGraph()
+ >>> G1.add_weighted_edges_from([(1, 2, 0.5), (2, 3, 0.8), (3, 4, 0.7)])
+ >>> G2 = nx.DiGraph()
+ >>> G2.add_weighted_edges_from([(1, 2, 0.3), (2, 3, 0.6), (3, 4, 0.9), (4, 5, 0.4)])
+
+ # Obtain Adjacency matrix
+ >>> pygm.utils.from_networkx(G1)
+ matrix([[0. , 0.5, 0. , 0. ],
+ [0. , 0. , 0.8, 0. ],
+ [0. , 0. , 0. , 0.7],
+ [0. , 0. , 0. , 0. ]])
+
+ >>> pygm.utils.from_networkx(G2)
+ matrix([[0. , 0.3, 0. , 0. , 0. ],
+ [0. , 0. , 0.6, 0. , 0. ],
+ [0. , 0. , 0. , 0.9, 0. ],
+ [0. , 0. , 0. , 0. , 0.4],
+ [0. , 0. , 0. , 0. , 0. ]])
+
+ """
+ is_directed = isinstance(G, nx.DiGraph)
+ adj_matrix = nx.to_numpy_matrix(G,nodelist=G.nodes()) if is_directed else nx.to_numpy_matrix(G)
+ return adj_matrix
+
+
+def to_networkx(adj_matrix, backend=None):
+ """
+ Convert adjacency matrix to NetworkX object
+
+ :param adj_matrix: the adjacency matrix to convert
+ :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
+ :return: the NetworkX object corresponding to the adjacency matrix
+
+ .. dropdown:: Example
+
+ ::
+
+ >>> import networkx as nx
+ >>> import pygmtools as pygm
+ >>> pygm.BACKEND = 'numpy'
+
+ # Generate adjacency matrix
+ >>> adj_matrix = np.random.random(size=(4,4))
+
+ # Obtain NetworkX object
+ >>> pygm.utils.to_networkx(adj_matrix)
+
+ """
+ if backend is None:
+ backend = pygmtools.BACKEND
+ adj_matrix = to_numpy(adj_matrix, backend=backend)
+
+ if adj_matrix.ndim == 3 and adj_matrix.shape[0] == 1:
+ adj_matrix.squeeze(0)
+ assert adj_matrix.ndim == 2, 'Request the dimension of adj_matrix is 2'
+
+ G = nx.DiGraph() if np.any(adj_matrix != adj_matrix.T) else nx.Graph()
+ G.add_nodes_from(range(adj_matrix.shape[0]))
+ for i, j in zip(*np.where(adj_matrix)):
+ G.add_edge(i, j, weight=adj_matrix[i, j])
+ return G
+
+
+def from_graphml(filename):
+ r"""
+ Read a GraphML file and convert it to an adjacency matrix
+
+ :param filename: graphml file path
+ :return: the adjacency matrix corresponding to the graphml object
+
+ .. dropdown:: Example
+
+ ::
+
+ >>> import pygmtools as pygm
+ >>> pygm.BACKEND = 'numpy'
+
+ # example file (.graphml) path
+ >>> G1_path = 'examples/data/graph1.graphml'
+ >>> G2_path = 'examples/data/graph2.graphml'
+
+ # Obtain Adjacency matrix
+ >>> G1 = pygm.utils.from_graphml(G1_path)
+ >>> G1.shape
+ (11, 11)
+
+ >>> G2 = pygm.utils.from_graphml(G2_path)
+ >>> G2.shape
+ (11, 11)
+ """
+ if not filename.endswith('.graphml'):
+ raise ValueError("File name should end with '.graphml'")
+ if not os.path.isfile(filename):
+ raise ValueError("File not found: {}".format(filename))
+ return from_networkx(nx.read_graphml(filename))
+
+
+def to_graphml(adj_matrix, filename, backend=None):
+ r"""
+ Write an adjacency matrix to a GraphML file
+
+ :param adj_matrix: numpy.ndarray, the adjacency matrix to write
+ :param filename: str, the name of the output file
+ :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation.
+
+ .. dropdown:: Example
+
+ ::
+
+ >>> import pygmtools as pygm
+ >>> import numpy as np
+ >>> pygm.BACKEND = 'numpy'
+
+ # Generate adjacency matrix
+ >>> adj_matrix = np.random.random(size=(4,4))
+ >>> filename = 'examples/data/test.graphml'
+ >>> adj_matrix
+ array([[0.29440151, 0.66468829, 0.05403941, 0.85887567],
+ [0.48120964, 0.01429095, 0.73536659, 0.02962113],
+ [0.3815578 , 0.93356234, 0.01332568, 0.61149257],
+ [0.15422904, 0.64656912, 0.93219422, 0.784769 ]])
+
+ # Write GraphML file
+ >>> pygm.utils.to_graphml(adj_matrix, filename)
+
+ # Check the generated GraphML file
+ >>> pygm.utils.from_graphml(filename)
+ array([[0.29440151, 0.66468829, 0.05403941, 0.85887567],
+ [0.48120964, 0.01429095, 0.73536659, 0.02962113],
+ [0.3815578 , 0.93356234, 0.01332568, 0.61149257],
+ [0.15422904, 0.64656912, 0.93219422, 0.784769 ]])
+ """
+ nx.write_graphml(to_networkx(adj_matrix, backend), filename)
+
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5f1f3385..4a11a3af 100644
--- a/setup.py
+++ b/setup.py
@@ -9,14 +9,40 @@
import sys
from shutil import rmtree
import re
-
+import platform
from setuptools import find_packages, setup, Command
+import tarfile
+import distro
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+import shutil
+
def get_property(prop, project):
result = re.search(r'{}\s*=\s*[\'"]([^\'"]*)[\'"]'.format(prop), open(project + '/__init__.py').read())
return result.group(1)
+def get_os_and_python_version():
+ system = platform.system()
+ python_version = ".".join(map(str, sys.version_info[:2]))
+ if system.lower() == "windows":
+ os_version = "windows"
+ elif system.lower() == "darwin":
+ os_version = "macos"
+ elif system.lower() == "linux":
+ os_version = distro.name().lower()
+ else:
+ raise ValueError("Unknown System")
+ if (python_version == '3.11'):
+ python_version = '3.10'
+ return os_version, python_version
+
+
+def untar_file(tar_file_path, extract_folder_path):
+ with tarfile.open(tar_file_path, 'r:gz') as tarObj:
+ tarObj.extractall(extract_folder_path)
+
+
# Package meta-data.
NAME = 'pygmtools'
DESCRIPTION = 'pygmtools provides graph matching solvers in Python API and supports numpy and pytorch backends. ' \
@@ -24,11 +50,21 @@ def get_property(prop, project):
URL = 'https://pygmtools.readthedocs.io/'
AUTHOR = get_property('__author__', NAME)
VERSION = get_property('__version__', NAME)
-
REQUIRED = [
'requests>=2.25.1', 'scipy>=1.4.1', 'Pillow>=7.2.0', 'numpy>=1.18.5', 'easydict>=1.7', 'appdirs>=1.4.4', 'tqdm>=4.64.1','wget>=3.2'
]
-
+FILE={'windows':{ '3.7':'a_star.cp37-win_amd64.pyd',
+ '3.8':'a_star.cp38-win_amd64.pyd',
+ '3.9':'a_star.cp39-win_amd64.pyd',
+ '3.10':'a_star.cp310-win_amd64.pyd'},
+ 'macos' :{ '3.7':'a_star.cpython-37m-darwin.so',
+ '3.8':'a_star.cpython-38-darwin.so',
+ '3.9':'a_star.cpython-39-darwin.so',
+ '3.10':'a_star.cpython-310-darwin.so'},
+ 'ubuntu' :{ '3.7':'a_star.cpython-37m-x86_64-linux-gnu.so',
+ '3.8':'a_star.cpython-38-x86_64-linux-gnu.so',
+ '3.9':'a_star.cpython-39-x86_64-linux-gnu.so',
+ '3.10':'a_star.cpython-310-x86_64-linux-gnu.so'}}
EXTRAS = {}
here = os.path.abspath(os.path.dirname(__file__))
@@ -50,6 +86,18 @@ def get_property(prop, project):
else:
about['__version__'] = VERSION
+if os.path.exists(os.path.join(NAME,'astar/a_star.tar.gz')):
+ untar_file(os.path.join(NAME,'astar/a_star.tar.gz'),os.path.join(NAME,'astar'))
+
+
+class CustomBdistWheelCommand(_bdist_wheel):
+ def run(self):
+ os_version, python_version = get_os_and_python_version()
+ dynamic_link = FILE[os_version][python_version]
+ shutil.copy2(os.path.join(NAME, 'astar', dynamic_link), os.path.join(NAME, dynamic_link))
+ shutil.rmtree(os.path.join(NAME, 'astar'))
+ _bdist_wheel.run(self)
+
class UploadCommand(Command):
"""Support setup.py upload."""
@@ -97,12 +145,13 @@ def run(self):
author=AUTHOR,
url=URL,
packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
+ package_data={NAME: ['astar/*','*.pyd','*.so']},
install_requires=REQUIRED,
extras_require=EXTRAS,
include_package_data=True,
license='Mulan PSL v2',
python_requires='>=3.7',
- classifiers=(
+ classifiers=[
'License :: OSI Approved :: Mulan Permissive Software License v2 (MulanPSL-2.0)',
'Programming Language :: Python :: 3 :: Only',
'Operating System :: OS Independent',
@@ -111,9 +160,10 @@ def run(self):
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Image Recognition',
'Topic :: Scientific/Engineering :: Mathematics',
- ),
+ ],
# $ setup.py publish support.
cmdclass={
'upload': UploadCommand,
+ 'bdist_wheel': CustomBdistWheelCommand,
},
)
diff --git a/tests/requirements.txt b/tests/requirements.txt
index 258519a7..a1b9cfe1 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -10,4 +10,8 @@ tqdm
jittor==1.3.5.37
appdirs>=1.4.4
tensorflow==2.9.3
+distro
+cython
wget
+networkx==2.8.8
+
diff --git a/tests/requirements_win_mac.txt b/tests/requirements_win_mac.txt
index 948db8d6..5b0fcd8d 100644
--- a/tests/requirements_win_mac.txt
+++ b/tests/requirements_win_mac.txt
@@ -10,4 +10,7 @@ tqdm
appdirs>=1.4.4
tensorflow==2.9.3
mindspore==1.10.0
+distro
+cython
wget
+networkx==2.8.8
\ No newline at end of file
diff --git a/tests/test_a_star/a_star.pyx b/tests/test_a_star/a_star.pyx
new file mode 100644
index 00000000..619ead7f
--- /dev/null
+++ b/tests/test_a_star/a_star.pyx
@@ -0,0 +1,166 @@
+# distutils: language = c++
+import torch
+import numpy as np
+cimport cython
+cimport numpy as np
+#from libcpp.queue cimport priority_queue
+from libcpp.vector cimport vector
+from libcpp.pair cimport pair
+from libcpp cimport bool
+
+cdef extern from "priority_queue.hpp":
+ cdef cppclass TreeNode:
+ TreeNode()
+ TreeNode(long)
+ #TreeNode(vector[pair[long, long]], double, long)
+ pair[vector[long], vector[long]] x_indices
+ double gplsh
+ long idx
+
+ cdef cppclass tree_node_priority_queue:
+ tree_node_priority_queue(...) # get Cython to accept any arguments and let C++ deal with getting them right
+ void push(TreeNode)
+ TreeNode top()
+ void pop()
+ bool empty()
+ long size()
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def a_star(
+ data,
+ k,
+ vector[long] ns_1,
+ vector[long] ns_2,
+ net_pred_func,
+ heuristic_func,
+ bool net_pred=True,
+ long beam_width=0,
+ double trust_fact=1.,
+ long no_pred_size=0,
+):
+ # declare static dtypes
+ cdef long batch_num, b, n1, n2, _n2, ns_1b, ns_2b, max_ns_1, max_ns_2, extra_n2_cnt
+ cdef vector[long] tree_size
+ cdef double h_p, g_p
+ cdef vector[tree_node_priority_queue] open_set
+ cdef tree_node_priority_queue cur_set
+ cdef TreeNode selected, new_node
+ cdef vector[bool] stop_flags
+ cdef bool flag
+
+ batch_num = k.shape[0]
+
+ max_ns_1 = max(ns_1)
+ max_ns_2 = max(ns_2)
+
+ open_set = vector[tree_node_priority_queue](batch_num)
+ tree_size = vector[long](batch_num)
+ for b in range(batch_num):
+ open_set[b].push(TreeNode())
+ ret_x = torch.zeros(batch_num, max_ns_1+1, max_ns_2+1, device=k.device)
+ x_dense = torch.zeros(max_ns_1+1, max_ns_2+1, device=k.device)
+ stop_flags = vector[bool](batch_num, 0)
+ while not all(stop_flags):
+ for b in range(batch_num):
+ ns_1b = ns_1[b]
+ ns_2b = ns_2[b]
+
+ if stop_flags[b] == 1:
+ continue
+
+ selected = open_set[b].top()
+ open_set[b].pop()
+ #selected_x_indices = torch.tensor(selected.x_indices, dtype=torch.long).reshape(-1, 2)
+ if selected.idx == ns_1b:
+ stop_flags[b] = 1
+ #indices = selected_x_indices
+ #v = torch.ones(indices.shape[0], device=k.device)
+ #x = torch.sparse.FloatTensor(indices.t(), v, x_size).to_dense()
+ ret_x[b][selected.x_indices] = 1
+ continue
+
+ if beam_width > 0:
+ cur_set = tree_node_priority_queue()
+ flag = False
+ for n2 in range(ns_2b + 1):
+ if n2 != ns_2b and is_in(n2, selected.x_indices.second):
+ continue
+ if selected.idx + 1 == ns_1b:
+ flag = True
+ extra_n2_cnt = 0
+ for _n2 in range(ns_2b):
+ if _n2 != n2 and not is_in(_n2, selected.x_indices.second):
+ extra_n2_cnt += 1
+ new_node = TreeNode(ns_1b + extra_n2_cnt)
+ n1 = 0
+ for _ in range(selected.idx):
+ new_node.x_indices.first[n1] = selected.x_indices.first[n1]
+ new_node.x_indices.second[n1] = selected.x_indices.second[n1]
+ n1 += 1
+ new_node.x_indices.first[n1] = selected.idx
+ new_node.x_indices.second[n1] = n2
+ n1 += 1
+ for _n2 in range(ns_2b):
+ if _n2 != n2 and not is_in(_n2, selected.x_indices.second):
+ new_node.x_indices.first[n1] = ns_1b
+ new_node.x_indices.second[n1] = _n2
+ n1 += 1
+ else:
+ new_node = TreeNode(selected.idx + 1)
+ n1 = 0
+ for _ in range(selected.idx):
+ new_node.x_indices.first[n1] = selected.x_indices.first[n1]
+ new_node.x_indices.second[n1] = selected.x_indices.second[n1]
+ n1 += 1
+ new_node.x_indices.first[n1] = selected.idx
+ new_node.x_indices.second[n1] = n2
+ n1 += 1
+
+ x_dense[:] = 0
+ x_dense[new_node.x_indices] = 1
+
+ g_p = comp_ged(x_dense, k[b])
+
+ if net_pred:
+ if selected.idx + 1 == ns_1b or trust_fact <= 0. or ns_1b - (selected.idx + 1) < no_pred_size:
+ h_p = 0
+ else:
+ h_p = net_pred_func(data, x_dense)
+ else:
+ if selected.idx + 1 == ns_1b or trust_fact <= 0. or ns_1b - (selected.idx + 1) < no_pred_size:
+ h_p = 0
+ else:
+ h_p = heuristic_func(k[b], ns_1b, ns_2b, x_dense)
+
+ new_node.gplsh = g_p + h_p * trust_fact
+ new_node.idx = selected.idx + 1
+
+ if beam_width > 0:
+ cur_set.push(new_node)
+ else:
+ open_set[b].push(new_node)
+ tree_size[b] += 1
+ if(flag):
+ break
+
+ if beam_width > 0:
+ for i in range(min(beam_width, cur_set.size())):
+ open_set[b].push(cur_set.top())
+ cur_set.pop()
+ tree_size[b] += 1
+
+ return ret_x, tree_size
+
+
+cdef double comp_ged(_x, _k):
+ return torch.mm(torch.mm(_x.reshape( 1, -1), _k), _x.reshape( -1, 1))
+
+cdef bool is_in (long inp, vector[long] vec):
+ cdef unsigned long i
+ cdef bool ret = False
+ for i in range(vec.size()):
+ if inp == vec[i]:
+ ret = True
+ break
+ return ret
diff --git a/tests/test_a_star/a_star_setup.py b/tests/test_a_star/a_star_setup.py
new file mode 100644
index 00000000..2e76312e
--- /dev/null
+++ b/tests/test_a_star/a_star_setup.py
@@ -0,0 +1,18 @@
+from setuptools import setup, Extension
+from Cython.Build import cythonize
+import numpy as np
+from glob import glob
+setup(
+ name='a-star function',
+ ext_modules=cythonize(
+ Extension(
+ 'a_star',
+ glob('*.pyx'),
+ include_dirs=[np.get_include(),"."],
+ extra_compile_args=["-std=c++11"],
+ extra_link_args=["-std=c++11"],
+ ),
+ language_level = "3",
+ ),
+ zip_safe=False,
+)
diff --git a/tests/test_a_star/prepare_for_test.py b/tests/test_a_star/prepare_for_test.py
new file mode 100644
index 00000000..16096693
--- /dev/null
+++ b/tests/test_a_star/prepare_for_test.py
@@ -0,0 +1,26 @@
+import os
+import glob
+import shutil
+
+ori_dir = os.getcwd()
+os.chdir('tests/test_a_star')
+
+try:
+ os.system("python a_star_setup.py build_ext --inplace")
+except:
+ os.system("python3 a_star_setup.py build_ext --inplace")
+
+current_dir = os.getcwd()
+
+ext_files = glob.glob(os.path.join(current_dir, '*.pyd')) + \
+ glob.glob(os.path.join(current_dir, '*.so'))
+
+if len(ext_files) == 0:
+ raise ValueError("there is no .pyd or .so")
+elif len(ext_files) > 1:
+ raise ValueError("too many files end with .pyd or .so")
+else:
+ target_dir = os.path.abspath(os.path.join(current_dir,'..','..','pygmtools'))
+ shutil.copy(ext_files[0], target_dir)
+
+os.chdir(ori_dir)
diff --git a/tests/test_a_star/priority_queue.hpp b/tests/test_a_star/priority_queue.hpp
new file mode 100644
index 00000000..7c2905d6
--- /dev/null
+++ b/tests/test_a_star/priority_queue.hpp
@@ -0,0 +1,41 @@
+#include <queue>
+#include <vector>
+
+struct TreeNode
+{
+ std::pair<std::vector<long>, std::vector<long> > x_indices;
+ double gplsh;
+ long idx;
+ TreeNode();
+ TreeNode(const int &);
+ TreeNode(const std::pair<std::vector<long>, std::vector<long> > &, const double &, const long &);
+ bool operator>(const TreeNode &) const;
+};
+
+TreeNode::TreeNode()
+{
+ this->x_indices = std::pair<std::vector<long>, std::vector<long> >();
+ this->gplsh = 0;
+ this->idx = 0;
+}
+
+TreeNode::TreeNode(const int & len)
+{
+ this->x_indices = std::pair<std::vector<long>, std::vector<long> >(std::vector<long>(len), std::vector<long>(len));
+ this->gplsh = 0;
+ this->idx = 0;
+}
+
+TreeNode::TreeNode(const std::pair<std::vector<long>, std::vector<long> > &x_indices, const double &gplsh, const long &idx)
+{
+ this->x_indices = x_indices;
+ this->gplsh = gplsh;
+ this->idx = idx;
+}
+
+bool TreeNode::operator>(const TreeNode &c) const
+{
+ return this->gplsh > c.gplsh;
+}
+
+using tree_node_priority_queue = std::priority_queue<TreeNode, std::vector<TreeNode>, std::greater<TreeNode> >;
diff --git a/tests/test_classic_solvers.py b/tests/test_classic_solvers.py
index a5dbf9a6..fbd48ee2 100644
--- a/tests/test_classic_solvers.py
+++ b/tests/test_classic_solvers.py
@@ -27,6 +27,8 @@
def get_backends(backend):
if backend == "all":
backends = ['pytorch', 'numpy', 'paddle', 'jittor', 'tensorflow'] if os_name == 'Linux' else ['pytorch', 'numpy', 'paddle', 'tensorflow']
+ elif backend == '':
+ backends = ['pytorch']
else:
backends = ["pytorch", backend]
return backends
@@ -208,6 +210,110 @@ def _test_classic_solver_on_linear_assignment(num_nodes1, num_nodes2, node_feat_
last_X = pygm.utils.to_numpy(_X)
+# The testing function for a_star
+def _test_astar(graph_num_nodes, node_feat_dim, solver_func, matrix_params, backends):
+ if backends[0] != 'pytorch':
+ backends.insert(0, 'pytorch') # force pytorch as the reference backend
+ backends = ['pytorch'] # Due to currently only supporting pytorch, testing is only conducted under pytorch
+ batch_size = len(graph_num_nodes)
+
+ # Generate isomorphic graphs
+ pygm.BACKEND = 'pytorch'
+ torch.manual_seed(0)
+ X_gt, A1, A2, F1, F2, = [], [], [], [], [],
+ for b, num_node in enumerate(graph_num_nodes):
+ As_b, X_gt_b, Fs_b = pygm.utils.generate_isomorphic_graphs(num_node, node_feat_dim=node_feat_dim)
+ Fs_b = Fs_b - 0.5
+ X_gt.append(X_gt_b)
+ A1.append(As_b[0])
+ A2.append(As_b[1])
+ F1.append(Fs_b[0])
+ F2.append(Fs_b[1])
+ n1 = torch.tensor(graph_num_nodes, dtype=torch.int)
+ n2 = torch.tensor(graph_num_nodes, dtype=torch.int)
+ A1, A2, F1, F2, X_gt = (pygm.utils.build_batch(_) for _ in (A1, A2, F1, F2, X_gt))
+ if batch_size > 1:
+ A1, A2, F1, F2, n1, n2, X_gt = data_to_numpy(A1, A2, F1, F2, n1, n2, X_gt)
+ else:
+ A1, A2, F1, F2, n1, n2, X_gt = data_to_numpy(
+ A1.squeeze(0), A2.squeeze(0), F1.squeeze(0), F2.squeeze(0), n1, n2, X_gt.squeeze(0)
+ )
+
+ # call the solver
+ total = 1
+ for val in matrix_params.values():
+ total *= len(val)
+ for values in tqdm(itertools.product(*matrix_params.values()), total=total):
+ solver_param_dict = {}
+ for k, v in zip(matrix_params.keys(), values):
+ solver_param_dict[k] = v
+
+ last_X = None
+ for working_backend in backends:
+ pygm.BACKEND = working_backend
+ _A1, _A2, _F1, _F2, _n1, _n2 = data_from_numpy(A1, A2, F1, F2, n1, n2)
+ _X1 = solver_func(_F1, _F2, _A1, _A2, _n1, _n2, **solver_param_dict)
+
+ if last_X is not None:
+ assert np.abs(pygm.utils.to_numpy(_X1) - last_X).sum() < 5e-3, \
+ f"Incorrect GM solution for {working_backend}; " \
+ f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}"
+
+ last_X = pygm.utils.to_numpy(_X1)
+ accuracy = (pygm.utils.to_numpy(pygm.hungarian(_X1, _n1, _n2)) * X_gt).sum() / X_gt.sum()
+ assert accuracy == 1, f"GM is inaccurate for {working_backend}, accuracy={accuracy:.4f}; " \
+ f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}"
+
+
+# The testing function for networkx
+def _test_networkx(graph_num_nodes, backends):
+ """
+ Test the RRWM algorithm on pairs of isomorphic graphs using NetworkX
+
+ :param graph_num_nodes: list, the numbers of nodes in the graphs to test
+ """
+ for working_backend in backends:
+ pygm.BACKEND = working_backend
+ for num_node in tqdm(graph_num_nodes):
+ As_b, X_gt = pygm.utils.generate_isomorphic_graphs(num_node)
+ X_gt = pygm.utils.to_numpy(X_gt, backend=working_backend)
+ A1 = As_b[0]
+ A2 = As_b[1]
+ G1 = pygm.utils.to_networkx(A1)
+ G2 = pygm.utils.to_networkx(A2)
+ K = pygm.utils.build_aff_mat_from_networkx(G1, G2)
+ X = pygm.rrwm(K, n1=num_node, n2=num_node)
+ accuracy = (pygm.utils.to_numpy(pygm.hungarian(X, num_node, num_node)) * X_gt).sum() / X_gt.sum()
+ assert accuracy == 1, f'When testing the networkx function with rrwm algorithm, there is an error in accuracy, \
+ and the accuracy is {accuracy}, the num_node is {num_node},.'
+
+
+# The testing function for graphml
+def _test_graphml(graph_num_nodes, backends):
+ """
+ Test the RRWM algorithm on pairs of isomorphic graphs using graphml
+
+ :param graph_num_nodes: list, the numbers of nodes in the graphs to test
+ """
+ filename = 'examples/data/test_graphml_{}.graphml'
+ filename_1 = filename.format(1)
+ filename_2 = filename.format(2)
+ for working_backend in backends:
+ pygm.BACKEND = working_backend
+ for num_node in tqdm(graph_num_nodes):
+ As_b, X_gt = pygm.utils.generate_isomorphic_graphs(num_node)
+ X_gt = pygm.utils.to_numpy(X_gt, backend=working_backend)
+ A1 = As_b[0]
+ A2 = As_b[1]
+ pygm.utils.to_graphml(A1, filename_1, backend=working_backend)
+ pygm.utils.to_graphml(A2, filename_2, backend=working_backend)
+ K = pygm.utils.build_aff_mat_from_graphml(filename_1, filename_2)
+ X = pygm.rrwm(K, n1=num_node, n2=num_node)
+ accuracy = (pygm.utils.to_numpy(pygm.hungarian(X, num_node, num_node)) * X_gt).sum() / X_gt.sum()
+ assert accuracy == 1, f'When testing the graphml function with rrwm algorithm, there is an error in accuracy, \
+ and the accuracy is {accuracy}, the num_node is {num_node},.'
+
+
def test_hungarian(get_backend):
backends = get_backends(get_backend)
_test_classic_solver_on_linear_assignment(list(range(10, 30, 2)), list(range(30, 10, -2)), 10, pygm.hungarian, {
@@ -343,6 +449,36 @@ def test_ipfp(get_backend):
'edge_aff_fn': [functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.)],
'node_aff_fn': [functools.partial(pygm.utils.gaussian_aff_fn, sigma=.1)]
}, backends)
+
+
+def test_astar(get_backend):
+ backends = get_backends(get_backend)
+ # heuristic_prediction
+ args1 = (list(range(10, 16, 2)), 10, pygm.astar,{
+ "beam_width": [0, 1, 2],
+ "trust_fact": [0.9, 0.95, 1.0],
+ "no_pred_size": [0, 1],
+ }, backends)
+
+ # non-batched input
+ args2 = ([10], 10, pygm.astar,{
+ "beam_width": [0, 1, 2],
+ "trust_fact": [0.9, 0.95, 1.0],
+ "no_pred_size": [0, 1],
+ }, backends)
+
+ _test_astar(*args1)
+ _test_astar(*args2)
+
+
+def test_networkx():
+ backends = ['pytorch', 'numpy']
+ _test_networkx(list(range(10, 30, 2)), backends=backends)
+
+
+def test_graphml():
+ backends = ['pytorch', 'numpy']
+ _test_graphml(list(range(10, 30, 2)), backends=backends)
if __name__ == '__main__':
@@ -351,3 +487,6 @@ def test_ipfp(get_backend):
test_rrwm('all')
test_sm('all')
test_ipfp('all')
+ test_astar('')
+ test_networkx()
+ test_graphml()
diff --git a/tests/test_neural_solvers.py b/tests/test_neural_solvers.py
index d6fe7a42..64b1e28b 100644
--- a/tests/test_neural_solvers.py
+++ b/tests/test_neural_solvers.py
@@ -136,6 +136,71 @@ def _test_neural_solver_on_isomorphic_graphs(graph_num_nodes, node_feat_dim, sol
f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}"
+# The testing function for genn_astar
+def _test_genn_astar(graph_num_nodes, node_feat_dim, solver_func, matrix_params, backends):
+ if backends[0] != 'pytorch':
+ backends.insert(0, 'pytorch') # force pytorch as the reference backend
+ backends = ['pytorch'] # Due to currently only supporting pytorch, testing is only conducted under pytorch
+ batch_size = len(graph_num_nodes)
+
+ # Generate isomorphic graphs
+ pygm.BACKEND = 'pytorch'
+ torch.manual_seed(0)
+ X_gt, A1, A2, F1, F2, = [], [], [], [], [],
+ for b, num_node in enumerate(graph_num_nodes):
+ As_b, X_gt_b, Fs_b = pygm.utils.generate_isomorphic_graphs(num_node, node_feat_dim=node_feat_dim)
+ Fs_b = Fs_b - 0.5
+ X_gt.append(X_gt_b)
+ A1.append(As_b[0])
+ A2.append(As_b[1])
+ F1.append(Fs_b[0])
+ F2.append(Fs_b[1])
+ n1 = torch.tensor(graph_num_nodes, dtype=torch.int)
+ n2 = torch.tensor(graph_num_nodes, dtype=torch.int)
+ A1, A2, F1, F2, X_gt = (pygm.utils.build_batch(_) for _ in (A1, A2, F1, F2, X_gt))
+ if batch_size > 1:
+ A1, A2, F1, F2, n1, n2, X_gt = data_to_numpy(A1, A2, F1, F2, n1, n2, X_gt)
+ else:
+ A1, A2, F1, F2, n1, n2, X_gt = data_to_numpy(
+ A1.squeeze(0), A2.squeeze(0), F1.squeeze(0), F2.squeeze(0), n1, n2, X_gt.squeeze(0)
+ )
+
+ # call the solver
+ total = 1
+ for val in matrix_params.values():
+ total *= len(val)
+ for values in tqdm(itertools.product(*matrix_params.values()), total=total):
+ solver_param_dict = {}
+ for k, v in zip(matrix_params.keys(), values):
+ solver_param_dict[k] = v
+
+ last_X = None
+ for working_backend in backends:
+ pygm.BACKEND = working_backend
+ _A1, _A2, _F1, _F2, _n1, _n2 = data_from_numpy(A1, A2, F1, F2, n1, n2)
+ _X1, net = solver_func(_F1, _F2, _A1, _A2, _n1, _n2, return_network=True, **solver_param_dict)
+ _X2 = solver_func(_F1, _F2, _A1, _A2, _n1, _n2, network=net, **solver_param_dict)
+ net2 = pygm.utils.get_network(solver_func, **solver_param_dict)
+ assert type(net) == type(net2)
+
+ assert np.abs(pygm.utils.to_numpy(_X1) - pygm.utils.to_numpy(_X2)).sum() < 1e-4, \
+ f"GM result inconsistent for predefined network object. backend={working_backend}; " \
+ f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}"
+
+ if 'pretrain' in solver_param_dict and solver_param_dict['pretrain'] is None:
+ _X1 = pygm.hungarian(_X1, _n1, _n2)
+
+ if last_X is not None:
+ assert np.abs(pygm.utils.to_numpy(_X1) - last_X).sum() < 5e-3, \
+ f"Incorrect GM solution for {working_backend}; " \
+ f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}"
+
+ last_X = pygm.utils.to_numpy(_X1)
+ accuracy = (pygm.utils.to_numpy(pygm.hungarian(_X1, _n1, _n2)) * X_gt).sum() / X_gt.sum()
+ assert accuracy == 1, f"GM is inaccurate for {working_backend}, accuracy={accuracy:.4f}; " \
+ f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}"
+
+
def test_pca_gm():
_test_neural_solver_on_isomorphic_graphs(list(range(10, 30, 2)), 1024, pygm.pca_gm, 'individual-graphs', {
'pretrain': ['voc', 'willow', 'voc-all'],
@@ -186,6 +251,7 @@ def test_cie():
'pretrain': [None],
}, backends)
+
def test_ngm():
_test_neural_solver_on_isomorphic_graphs(list(range(10, 30, 2)), 1024, pygm.ngm, 'lawler-qap', {
'edge_aff_fn': [functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.), pygm.utils.inner_prod_aff_fn],
@@ -201,8 +267,49 @@ def test_ngm():
}, backends)
+def test_genn_astar():
+ # test pretrained by AIDS700nef
+ args1 = (list(range(10, 30, 2)), 36, pygm.genn_astar,{
+ "pretrain": ["AIDS700nef"],
+ "beam_width": [0, 1, 2],
+ "trust_fact": [0.9, 0.95, 1.0],
+ "no_pred_size": [0, 1],
+
+ }, backends)
+
+ # non-batched input
+ args2 = ([10], 36, pygm.genn_astar,{
+ 'pretrain': ["AIDS700nef"],
+ "beam_width": [0, 1, 2],
+ "trust_fact": [0.9, 0.95, 1.0],
+ "no_pred_size": [0, 1],
+ }, backends)
+
+ # test pretrained by LINUX
+ args3 = (list(range(10, 30, 2)), 8, pygm.genn_astar,{
+ 'pretrain': ['LINUX'],
+ "beam_width": [0, 1, 2],
+ "trust_fact": [0.9, 0.95, 1.0],
+ "no_pred_size": [0, 1],
+ }, backends)
+
+ # non-batched input
+ args4 = ([10], 8, pygm.genn_astar,{
+ 'pretrain': ['LINUX'],
+ "beam_width": [0, 1, 2],
+ "trust_fact": [0.9, 0.95, 1.0],
+ "no_pred_size": [0, 1],
+ }, backends)
+
+ _test_genn_astar(*args1)
+ _test_genn_astar(*args2)
+ _test_genn_astar(*args3)
+ _test_genn_astar(*args4)
+
+
if __name__ == '__main__':
test_pca_gm()
test_ipca_gm()
test_cie()
test_ngm()
+ test_genn_astar()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c0c123e7..25cc4434 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -29,3 +29,4 @@ def data_to_numpy(*data):
return return_list
else:
return return_list[0]
+
\ No newline at end of file