diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d817d00d..5b94bbe4 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -30,6 +30,9 @@ jobs: python -m pip install flake8 pytest-cov if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi if [ "${{ matrix.python-version }}" != "3.10" ]; then pip install mindspore==1.10.0; fi + - name: preparation for tests + run: | + python tests/test_a_star/prepare_for_test.py - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -64,6 +67,9 @@ jobs: python -m pip install --upgrade pip python -m pip install flake8 pytest-cov pip install -r tests/requirements_win_mac.txt + - name: preparation for tests + run: | + python tests/test_a_star/prepare_for_test.py - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -74,7 +80,6 @@ jobs: run: | pytest --cov=pygmtools --cov-report=xml --backend=mindspore tests/test_classic_solvers.py pytest --cov=pygmtools --cov-report=xml --cov-append - windows: runs-on: windows-latest @@ -94,6 +99,9 @@ jobs: python -m pip install --upgrade pip python -m pip install flake8 pytest-cov python -m pip install -r tests\requirements_win_mac.txt + - name: preparation for tests + run: | + python tests/test_a_star/prepare_for_test.py - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -103,4 +111,4 @@ jobs: - name: Test with pytest. They are divided into two runs because MindSpore will interfer with Paddle. 
run: | pytest --cov=pygmtools --cov-report=xml --backend=mindspore tests/test_classic_solvers.py - pytest --cov=pygmtools --cov-report=xml --cov-append + pytest --cov=pygmtools --cov-report=xml --cov-append \ No newline at end of file diff --git a/docs/images/astar.png b/docs/images/astar.png new file mode 100644 index 00000000..b21c5a64 Binary files /dev/null and b/docs/images/astar.png differ diff --git a/docs/requirements.txt b/docs/requirements.txt index 5c30f6f0..cdf54f14 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,4 +15,5 @@ matplotlib networkx scikit-learn wget +networkx==2.8.8 pygmtools diff --git a/examples/data/graph1.graphml b/examples/data/graph1.graphml new file mode 100644 index 00000000..fa9bdbbb --- /dev/null +++ b/examples/data/graph1.graphml @@ -0,0 +1,61 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/data/graph2.graphml b/examples/data/graph2.graphml new file mode 100644 index 00000000..228ec85a --- /dev/null +++ b/examples/data/graph2.graphml @@ -0,0 +1,67 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/pygmtools/__init__.py b/pygmtools/__init__.py index 6899fbfc..b2b3534a 100644 --- a/pygmtools/__init__.py +++ b/pygmtools/__init__.py @@ -10,9 +10,9 @@ from .benchmark import Benchmark from .linear_solvers import sinkhorn, hungarian -from .classic_solvers import rrwm, sm, ipfp +from .classic_solvers import rrwm, sm, ipfp, astar from .multi_graph_solvers import cao, mgm_floyd, gamgm -from .neural_solvers import pca_gm, ipca_gm, cie, ngm +from .neural_solvers import pca_gm, ipca_gm, cie, ngm, genn_astar import pygmtools.utils as utils BACKEND = 'numpy' __version__ = '0.3.8' diff --git a/pygmtools/astar/a_star.tar.gz b/pygmtools/astar/a_star.tar.gz new file mode 100644 index 00000000..a63c1c76 
Binary files /dev/null and b/pygmtools/astar/a_star.tar.gz differ diff --git a/pygmtools/classic_solvers.py b/pygmtools/classic_solvers.py index 914a514b..d42fc9e9 100644 --- a/pygmtools/classic_solvers.py +++ b/pygmtools/classic_solvers.py @@ -22,8 +22,9 @@ import importlib import pygmtools -from pygmtools.utils import NOT_IMPLEMENTED_MSG, _check_shape, _get_shape, _unsqueeze, _squeeze, _check_data_type - +from pygmtools.utils import NOT_IMPLEMENTED_MSG, _check_shape, _get_shape,\ + _unsqueeze, _squeeze, _check_data_type,from_numpy +import numpy as np def sm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None, max_iter: int=50, @@ -1061,6 +1062,117 @@ def ipfp(K, n1=None, n2=None, n1max=None, n2max=None, x0=None, return result +def astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, dropout=0, beam_width=0, + trust_fact=1, no_pred_size=0, backend=None): + r""" + ASTAR solver for graph matching (Lawler's QAP). + The **ASTAR** solver finds the optimal match between two graphs through heuristic search. + + :param feat1: :math:`(b\times n_1 \times d)` input feature of graph1 + :param feat2: :math:`(b\times n_2 \times d)` input feature of graph2 + :param A1: :math:`(b\times n_1 \times n_1)` input adjacency matrix of graph1 + :param A2: :math:`(b\times n_2 \times n_2)` input adjacency matrix of graph2 + :param n1: :math:`(b)` number of nodes in graph1. Optional if all equal to :math:`n_1` + :param n2: :math:`(b)` number of nodes in graph2. Optional if all equal to :math:`n_2` + :param channel: (default: None) Channel size of the input layer. If given, it must match the feature dimension (d) of feat1, feat2. + If not given, it will be defined by the feature dimension (d) of feat1, feat2. + Ignored if the network object isgiven (ignored if network!=None) + :param dropout: (default: 0) Dropout probability + :param beam_width: (default: 0) Size of beam-search witdh (0 = no beam). + :param trust_fact: (default: 1) The trust factor on GNN prediction (0 = no GNN). 
+ :param no_pred_size: (default: 0) If the smaller graph has no more than x nodes, stop using heuristics. + :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation. + :return: :math:`(b\times n_1 \times n_2)` the doubly-stochastic matching matrix + + .. note:: + This function also supports non-batched input, by ignoring all batch dimensions in the input tensors. + + .. dropdown:: PyTorch Example + + :: + + >>> import torch + >>> import pygmtools as pygm + >>> pygm.BACKEND = 'pytorch' + >>> _ = torch.manual_seed(1) + + # Generate a batch of isomorphic graphs + >>> batch_size = 10 + >>> nodes_num = 4 + >>> channel = 36 + + >>> X_gt = torch.zeros(batch_size, nodes_num, nodes_num) + >>> X_gt[:, torch.arange(0, nodes_num, dtype=torch.int64), torch.randperm(nodes_num)] = 1 + >>> A1 = 1. * (torch.rand(batch_size, nodes_num, nodes_num) > 0.5) + >>> torch.diagonal(A1, dim1=1, dim2=2)[:] = 0 # discard self-loop edges + >>> A2 = torch.bmm(torch.bmm(X_gt.transpose(1, 2), A1), X_gt) + >>> feat1 = torch.rand(batch_size, nodes_num, channel) - 0.5 + >>> feat2 = torch.bmm(X_gt.transpose(1, 2), feat1) + >>> n1 = n2 = torch.tensor([nodes_num] * batch_size) + + # Match by ASTAR (load pretrained model) + >>> X = pygm.astar(feat1, feat2, A1, A2, n1, n2) + >>> (X * X_gt).sum() / X_gt.sum()# accuracy + tensor(1.) + + # This function also supports non-batched input, by ignoring all batch dimensions in the input tensors. + >>> part_f1 = feat1[0] + >>> part_f2 = feat2[0] + >>> part_A1 = A1[0] + >>> part_A2 = A2[0] + >>> part_X_gt = X_gt[0] + >>> part_X = pygm.astar(part_f1, part_f2, part_A1, part_A2) + + >>> part_X.shape + torch.Size([4, 4]) + + >>> (part_X * part_X_gt).sum() / part_X_gt.sum()# accuracy + tensor(1.) 
+ """ + if backend is None: + backend = pygmtools.BACKEND + non_batched_input = False + if feat1 is not None: # if feat1 is None, this function skips the forward pass and only returns a network object + for _ in (feat1, feat2, A1, A2): + _check_data_type(_, backend) + + if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]): + feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)] + if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend) + if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend) + non_batched_input = True + elif all([_check_shape(_, 3, backend) for _ in (feat1, feat2, A1, A2)]): + non_batched_input = False + else: + raise ValueError( + f'the input arguments feat1, feat2, A1, A2 are expected to be all 2-dimensional or 3-dimensional, got ' + f'feat1:{len(_get_shape(feat1, backend))}dims, feat2:{len(_get_shape(feat2, backend))}dims, ' + f'A1:{len(_get_shape(A1, backend))}dims, A2:{len(_get_shape(A2, backend))}dims!') + + if not (_get_shape(feat1, backend)[0] == _get_shape(feat2, backend)[0] == _get_shape(A1, backend)[0] == _get_shape(A2, backend)[0])\ + or not (_get_shape(feat1, backend)[1] == _get_shape(A1, backend)[1] == _get_shape(A1, backend)[2])\ + or not (_get_shape(feat2, backend)[1] == _get_shape(A2, backend)[1] == _get_shape(A2, backend)[2])\ + or not (_get_shape(feat1, backend)[2] == _get_shape(feat2, backend)[2]): + raise ValueError( + f'the input dimensions do not match. 
Got feat1:{_get_shape(feat1, backend)}, ' + f'feat2:{_get_shape(feat2, backend)}, A1:{_get_shape(A1, backend)}, A2:{_get_shape(A2, backend)}!') + if n1 is not None: _check_data_type(n1, 'n1', backend) + if n2 is not None: _check_data_type(n2, 'n2', backend) + + args = (feat1, feat2, A1, A2, n1, n2, channel, dropout, beam_width, trust_fact, no_pred_size) + try: + mod = importlib.import_module(f'pygmtools.{backend}_backend') + fn = mod.astar + except (ModuleNotFoundError, AttributeError): + raise NotImplementedError( + NOT_IMPLEMENTED_MSG.format(backend) + ) + + result = fn(*args) + match_mat = _squeeze(result[0], 0, backend) if non_batched_input else result[0] + return match_mat + + def __check_gm_arguments(n1, n2, n1max, n2max): if n1 is None and n1max is None: raise ValueError('at least one of the following arguments are required: n1 and n1max.') diff --git a/pygmtools/neural_solvers.py b/pygmtools/neural_solvers.py index fa40bdad..5be71048 100644 --- a/pygmtools/neural_solvers.py +++ b/pygmtools/neural_solvers.py @@ -1271,3 +1271,211 @@ def ngm(K, n1=None, n2=None, n1max=None, n2max=None, x0=None, return match_mat, result[1] else: return match_mat + + +def genn_astar(feat1, feat2, A1, A2, n1=None, n2=None, channel=None, filters_1=64, filters_2=32, filters_3=16, + tensor_neurons=16, dropout=0, beam_width=0, trust_fact=1, no_pred_size=0, + network=None, return_network=False, pretrain='AIDS700nef', backend=None): + r""" + The **GENN-ASTAR** (Graph Edit Neural Network Astar) solver for graph matching based on the combination of traditional A-star and Neural Network. + This algorithm replaces the heuristic prediction module in the traditional A-star algorithm with **GNN** (Graph Neural Network) model, + greatly improving the efficiency of A-star algorithm while ensuring a certain degree of accuracy. + During the search process, the algorithm prioritizes the next search direction based on the distance between the current state and the target state. 
+ At each step of the search, the algorithm uses a predicted probability distribution of node pairs for matching. + + See the following picture to better understand the workflow of the algorithm: + + .. image:: ../../images/astar.png + + See the following paper for more technical details: + `"Combinatorial Learning of Graph Edit Distance via Dynamic Embedding" + `_ + + + :param feat1: :math:`(b\times n_1 \times d)` input feature of graph1 + :param feat2: :math:`(b\times n_2 \times d)` input feature of graph2 + :param A1: :math:`(b\times n_1 \times n_1)` input adjacency matrix of graph1 + :param A2: :math:`(b\times n_2 \times n_2)` input adjacency matrix of graph2 + :param n1: :math:`(b)` number of nodes in graph1. Optional if all equal to :math:`n_1` + :param n2: :math:`(b)` number of nodes in graph2. Optional if all equal to :math:`n_2` + :param channel: (default: None) Channel size of the input layer. If given, it must match the feature dimension (d) of feat1, feat2. + If not given, it will be defined by the feature dimension (d) of feat1, feat2. + Ignored if the network object is given (ignored if network!=None) + :param filters_1: (default: 64) Filters (neurons) in 1st convolution. + :param filters_2: (default: 32) Filters (neurons) in 2nd convolution. + :param filters_3: (default: 16) Filters (neurons) in 3rd convolution. + :param tensor_neurons: (default: 16) Neurons in tensor network layer. + :param dropout: (default: 0) Dropout probability + :param beam_width: (default: 0) Size of beam-search width (0 = no beam). + :param trust_fact: (default: 1) The trust factor on GNN prediction (0 = no GNN). + :param no_pred_size: (default: 0) If the smaller graph has no more than x nodes, stop using heuristics. + :param network: (default: None) The network object. If None, a new network object will be created and loaded with the + model weights specified in the ``pretrain`` argument. 
+ :param return_network: (default: False) Return the network object (saving model construction time if calling the + model multiple times). + :param pretrain: (default: 'AIDS700nef') If ``network==None``, the pretrained model weights to be loaded. Available + pretrained weights: ``AIDS700nef`` (channel=36), ``LINUX`` (channel=8), + or ``False`` (no pretraining). + :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation. + :return: if ``return_network==False``, :math:`(b\times n_1 \times n_2)` the doubly-stochastic matching matrix + + if ``return_network==True``, :math:`(b\times n_1 \times n_2)` the doubly-stochastic matching matrix, + the network object + + .. note:: + You may need a proxy to load the pretrained weights if Google Drive is not accessible in your country/region. + You may also download the pretrained models manually and put them at ``~/.cache/pygmtools`` (for Linux). + + `[google drive] `_ + + .. note:: + This function also supports non-batched input, by ignoring all batch dimensions in the input tensors. + + .. dropdown:: PyTorch Example + + :: + + >>> import torch + >>> import pygmtools as pygm + >>> pygm.BACKEND = 'pytorch' + >>> _ = torch.manual_seed(1) + + # Generate a batch of isomorphic graphs + >>> batch_size = 10 + >>> nodes_num = 4 + >>> channel = 36 + + >>> X_gt = torch.zeros(batch_size, nodes_num, nodes_num) + >>> X_gt[:, torch.arange(0, nodes_num, dtype=torch.int64), torch.randperm(nodes_num)] = 1 + >>> A1 = 1. 
* (torch.rand(batch_size, nodes_num, nodes_num) > 0.5) + >>> torch.diagonal(A1, dim1=1, dim2=2)[:] = 0 # discard self-loop edges + >>> A2 = torch.bmm(torch.bmm(X_gt.transpose(1, 2), A1), X_gt) + >>> feat1 = torch.rand(batch_size, nodes_num, channel) - 0.5 + >>> feat2 = torch.bmm(X_gt.transpose(1, 2), feat1) + >>> n1 = n2 = torch.tensor([nodes_num] * batch_size) + + # Match by GENN-ASTAR (load pretrained model) + >>> X, net = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, return_network=True) + Downloading to ~/.cache/pygmtools/best_genn_AIDS700nef_gcn_astar.pt... + >>> (X * X_gt).sum() / X_gt.sum()# accuracy + tensor(1.) + + # Pass the net object to avoid rebuilding the model again + >>> X = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, network=net) + + # This function also supports non-batched input, by ignoring all batch dimensions in the input tensors. + >>> part_f1 = feat1[0] + >>> part_f2 = feat2[0] + >>> part_A1 = A1[0] + >>> part_A2 = A2[0] + >>> part_X_gt = X_gt[0] + >>> part_X = pygm.genn_astar(part_f1, part_f2, part_A1, part_A2, return_network=False) + + >>> part_X.shape + torch.Size([4, 4]) + + >>> (part_X * part_X_gt).sum() / part_X_gt.sum()# accuracy + tensor(1.) + + # You may also load other pretrained weights + # However, it should be noted that each pretrained set supports different node feature dimensions + # AIDS700nef(Default): channel = 36 + # LINUX: channel = 8 + # Generate a batch of isomorphic graphs + >>> batch_size = 10 + >>> nodes_num = 4 + >>> channel = 8 + + >>> X_gt = torch.zeros(batch_size, nodes_num, nodes_num) + >>> X_gt[:, torch.arange(0, nodes_num, dtype=torch.int64), torch.randperm(nodes_num)] = 1 + >>> A1 = 1. 
* (torch.rand(batch_size, nodes_num, nodes_num) > 0.5) + >>> torch.diagonal(A1, dim1=1, dim2=2)[:] = 0 # discard self-loop edges + >>> A2 = torch.bmm(torch.bmm(X_gt.transpose(1, 2), A1), X_gt) + >>> feat1 = torch.rand(batch_size, nodes_num, channel) - 0.5 + >>> feat2 = torch.bmm(X_gt.transpose(1, 2), feat1) + >>> n1 = n2 = torch.tensor([nodes_num] * batch_size) + + >>> X, net = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, pretrain='LINUX', return_network=True) + Downloading to ~/.cache/pygmtools/best_genn_LINUX_gcn_astar.pt... + + >>> (X * X_gt).sum() / X_gt.sum()# accuracy + tensor(1.) + + # When the input node feature dimension is different from the one supported by the pretrained weights, + # you can still use the solver, but the solver will provide a warning + >>> X, net = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, return_network=True, pretrain='AIDS700nef') + Warning: Skip key(s) in state_dict: "convolution_1.weight". + Warning: Pretrain AIDS700nef does not support the parameters you entered, + Supported parameters: ( channel:36, filters:(64,32,16) ), + Input parameters: ( channel:8, filters:(64,32,16) ) + + # You may configure your own model and integrate the model into a deep learning pipeline. For example: + >>> net = pygm.utils.get_network(pygm.genn_astar, channel = 1000, filters_1 = 1024, filters_2 = 256, filters_3 = 128, pretrain=False) + >>> optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + # feat1/feat2 may be outputs by other neural networks + >>> X = pygm.genn_astar(feat1, feat2, A1, A2, n1, n2, network=net) + >>> loss = pygm.utils.permutation_loss(X, X_gt) + >>> loss.backward() + >>> optimizer.step() + + .. 
note:: + + If you find this model useful in your research, please cite: + + :: + + @inproceedings{WangCVPR21, + author={Runzhong Wang, Tianqi Zhang, Tianshu Yu, Junchi Yan, Xiaokang Yang}, + booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + title={Combinatorial Learning of Graph Edit Distance via Dynamic Embedding}, + year={2021}, + } + """ + + if backend is None: + backend = pygmtools.BACKEND + non_batched_input = False + if feat1 is not None: # if feat1 is None, this function skips the forward pass and only returns a network object + for _ in (feat1, feat2, A1, A2): + _check_data_type(_, backend) + + if all([_check_shape(_, 2, backend) for _ in (feat1, feat2, A1, A2)]): + feat1, feat2, A1, A2 = [_unsqueeze(_, 0, backend) for _ in (feat1, feat2, A1, A2)] + if type(n1) is int: n1 = from_numpy(np.array([n1]), backend=backend) + if type(n2) is int: n2 = from_numpy(np.array([n2]), backend=backend) + non_batched_input = True + elif all([_check_shape(_, 3, backend) for _ in (feat1, feat2, A1, A2)]): + non_batched_input = False + else: + raise ValueError( + f'the input arguments feat1, feat2, A1, A2 are expected to be all 2-dimensional or 3-dimensional, got ' + f'feat1:{len(_get_shape(feat1, backend))}dims, feat2:{len(_get_shape(feat2, backend))}dims, ' + f'A1:{len(_get_shape(A1, backend))}dims, A2:{len(_get_shape(A2, backend))}dims!') + + if not (_get_shape(feat1, backend)[0] == _get_shape(feat2, backend)[0] == _get_shape(A1, backend)[0] == _get_shape(A2, backend)[0])\ + or not (_get_shape(feat1, backend)[1] == _get_shape(A1, backend)[1] == _get_shape(A1, backend)[2])\ + or not (_get_shape(feat2, backend)[1] == _get_shape(A2, backend)[1] == _get_shape(A2, backend)[2])\ + or not (_get_shape(feat1, backend)[2] == _get_shape(feat2, backend)[2]): + raise ValueError( + f'the input dimensions do not match. 
Got feat1:{_get_shape(feat1, backend)}, ' + f'feat2:{_get_shape(feat2, backend)}, A1:{_get_shape(A1, backend)}, A2:{_get_shape(A2, backend)}!') + if n1 is not None: _check_data_type(n1, 'n1', backend) + if n2 is not None: _check_data_type(n2, 'n2', backend) + + args = (feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3, + tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain) + try: + mod = importlib.import_module(f'pygmtools.{backend}_backend') + fn = mod.genn_astar + except (ModuleNotFoundError, AttributeError): + raise NotImplementedError( + NOT_IMPLEMENTED_MSG.format(backend) + ) + + result = fn(*args) + match_mat = _squeeze(result[0], 0, backend) if non_batched_input else result[0] + if return_network: + return match_mat, result[1] + else: + return match_mat + \ No newline at end of file diff --git a/pygmtools/pytorch_astar_modules.py b/pygmtools/pytorch_astar_modules.py new file mode 100644 index 00000000..1c2c8540 --- /dev/null +++ b/pygmtools/pytorch_astar_modules.py @@ -0,0 +1,381 @@ +import torch +import pygmtools.utils +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from torch import Tensor +from typing import Optional, Tuple + +VERY_LARGE_INT = 65536 + +############################################################### +# GENN-A* Functions # +############################################################### + + +def default_parameter(): + params = dict() + params['cuda'] = False + params['pretrain'] = False + params['channel'] = 36 + params['filters_1'] = 64 + params['filters_2'] = 32 + params['filters_3'] = 16 + params['tensor_neurons'] = 16 + params['dropout'] = 0 + params['astar_beam_width'] = 0 + params['astar_trust_fact'] = 1 + params['astar_no_pred'] = 0 + params['use_net'] = True + return params + + +def check_layer_parameter(params): + if params['pretrain'] == 'AIDS700nef': + if params['channel'] != 36: + return False + elif params['pretrain'] == 
'LINUX': + if params['channel'] != 8: + return False + if params['filters_1'] != 64: + return False + if params['filters_2'] != 32: + return False + if params['filters_3'] != 16: + return False + if params['tensor_neurons'] != 16: + return False + return True + + +def node_metric(node1, node2): + + encoding = torch.sum(torch.abs(node1.unsqueeze(2) - node2.unsqueeze(1)), dim=-1) + non_zero = torch.nonzero(encoding) + for i in range(non_zero.shape[0]): + encoding[non_zero[i][0], non_zero[i][1], non_zero[i][2]] = 1 + return encoding + + +def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand(other.size()) + return src + + +def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + index = broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, index_dim, None, dim_size) + count[count < 1] = 1 + count = broadcast(count, out, dim) + if out.is_floating_point(): + 
out.true_divide_(count) + else: + out.div_(count, rounding_mode='floor') + return out + + +def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None, + fill_value: float = 0., max_num_nodes: Optional[int] = None, + batch_size: Optional[int] = None) -> Tuple[Tensor, Tensor]: + if batch is None and max_num_nodes is None: + mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device) + return x.unsqueeze(0), mask + + if batch is None: + batch = x.new_zeros(x.size(0), dtype=torch.long) + + if batch_size is None: + batch_size = int(batch.max()) + 1 + + num_nodes = scatter_sum(batch.new_ones(x.size(0)), batch, dim=0, + dim_size=batch_size) + cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) + + if max_num_nodes is None: + max_num_nodes = int(num_nodes.max()) + + idx = torch.arange(batch.size(0), dtype=torch.long, device=x.device) + idx = (idx - cum_nodes[batch]) + (batch * max_num_nodes) + + size = [batch_size * max_num_nodes] + list(x.size())[1:] + out = x.new_full(size, fill_value) + out[idx] = x + out = out.view([batch_size, max_num_nodes] + list(x.size())[1:]) + + mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool, + device=x.device) + mask[idx] = 1 + mask = mask.view(batch_size, max_num_nodes) + + return out, mask + + +def to_dense_adj(edge_index: Tensor,batch=None,edge_attr=None,max_num_nodes: Optional[int] = None) -> Tensor: + if batch is None: + num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 + batch = edge_index.new_zeros(num_nodes) + + batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1 + one = batch.new_ones(batch.size(0)) + num_nodes = scatter_sum(one, batch, dim=0, dim_size=batch_size) + cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) + + idx0 = batch[edge_index[0]] + idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]] + idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]] + + if max_num_nodes is None: + max_num_nodes = num_nodes.max().item() + + elif 
((idx1.numel() > 0 and idx1.max() >= max_num_nodes) + or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)): + mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes) + idx0 = idx0[mask] + idx1 = idx1[mask] + idx2 = idx2[mask] + edge_attr = None if edge_attr is None else edge_attr[mask] + + if edge_attr is None: + edge_attr = torch.ones(idx0.numel(), device=edge_index.device) + + size = [batch_size, max_num_nodes, max_num_nodes] + size += list(edge_attr.size())[1:] + adj = torch.zeros(size, dtype=edge_attr.dtype, device=edge_index.device) + + flattened_size = batch_size * max_num_nodes * max_num_nodes + adj = adj.view([flattened_size] + list(adj.size())[3:]) + idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2 + adj = scatter_sum(edge_attr, idx, dim=0, dim_size=flattened_size) + adj = adj.view(size) + + return adj + + +############################################################### +# GENN-A* Modules # +############################################################### + + +class GraphPair: + def __init__(self, x1: torch.Tensor, x2: torch.Tensor, adj1: torch.Tensor, + adj2: torch.Tensor, n1=None, n2=None): + self.g1 = Graphs(x1, adj1, n1) + self.g2 = Graphs(x2, adj2, n2) + + def __repr__(self): + return f"{self.__class__.__name__}('g1' = {self.g1}, 'g2' = {self.g2})" + + def to_dict(self): + data = dict() + data['g1'] = self.g1 + data['g2'] = self.g2 + return data + + +class Graphs: + def __init__(self, x: torch.Tensor, adj: torch.Tensor, nodes_num=None): + assert len(x.shape) == len(adj.shape) + if len(adj.shape) == 2: + adj = adj.unsqueeze(dim=0) + x = x.unsqueeze(dim=0) + assert x.shape[0] == adj.shape[0] + assert x.shape[1] == adj.shape[1] + self.x = x + self.adj = adj + self.num_graphs = adj.shape[0] + if nodes_num is not None: + self.nodes_num = nodes_num + else: + self.nodes_num = torch.tensor([x.shape[1]]*x.shape[0]) + if self.x.shape[0] == 1: + if self.x.shape[1] != nodes_num: + self.x = self.x[:, :nodes_num, :] + self.adj = self.adj[:, 
:nodes_num, :nodes_num] + self.edge_index = None + self.edge_weight = None + self.batch = None + self.graph_process() + + def graph_process(self): + edge_index, edge_weight, _ = pygmtools.utils.dense_to_sparse(self.adj) + self.edge_index = torch.cat([edge_index[:, :, 0].unsqueeze(dim=1), + edge_index[:, :, 1].unsqueeze(dim=1)], dim=1) + self.edge_weight = edge_weight.view(-1) + if self.nodes_num.shape == torch.Size([]): + batch = torch.tensor([0] * self.nodes_num) + else: + for i in range(len(self.nodes_num)): + if i == 0: + batch = torch.tensor([i] * self.nodes_num[i]) + else: + cur_batch = torch.tensor([i] * self.nodes_num[i]) + batch = torch.cat([batch, cur_batch]) + self.batch = batch + + def __repr__(self): + message = "x = {}, adj = {}".format(list(self.x.shape), list(self.adj.shape)) + message += " edge_index = {}, edge_weight = {}".format(list(self.edge_index.shape), list(self.edge_weight.shape)) + message += " nodes_num = {}, num_graphs = {})".format(self.nodes_num.shape, self.num_graphs) + self.message = message + return f"{self.__class__.__name__}({self.message})" + + +class AttentionModule(torch.nn.Module): + """ + SimGNN Attention Module to make a pass on graph. + """ + def __init__(self, args): + """ + :param args: Arguments object. + """ + super(AttentionModule, self).__init__() + self.args = args + self.setup_weights() + self.init_parameters() + + def setup_weights(self): + """ + Defining weights. + """ + self.weight_matrix = torch.nn.Parameter(torch.Tensor(self.args['filters_3'], self.args['filters_3'])) + + def init_parameters(self): + """ + Initializing weights. + """ + torch.nn.init.xavier_uniform_(self.weight_matrix) + + def forward(self, x, batch, size=None): + """ + Making a forward propagation pass to create a graph level representation. + :param x: Result of the GNN. + :param batch: Batch vector, which assigns each node to a specific example + :return representation: A graph level representation matrix. 
+ """ + size = batch[-1].item() + 1 if size is None else size + mean = scatter_mean(x, batch, dim=0, dim_size=size) + transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix)) + + coefs = torch.sigmoid((x * transformed_global[batch] * 10).sum(dim=1)) + weighted = coefs.unsqueeze(-1) * x + + return scatter_sum(weighted, batch, dim=0, dim_size=size) + + def get_coefs(self, x): + mean = x.mean(dim=0) + transformed_global = torch.tanh(torch.matmul(mean, self.weight_matrix)) + + return torch.sigmoid(torch.matmul(x, transformed_global)) + + +class TensorNetworkModule(torch.nn.Module): + """ + SimGNN Tensor Network module to calculate similarity vector. + """ + def __init__(self,args): + """ + :param args: Arguments object. + """ + super(TensorNetworkModule, self).__init__() + self.args = args + self.setup_weights() + self.init_parameters() + + def setup_weights(self): + """ + Defining weights. + """ + self.weight_matrix = torch.nn.Parameter(torch.Tensor(self.args['filters_3'], self.args['filters_3'], self.args['tensor_neurons'])) + self.weight_matrix_block = torch.nn.Parameter(torch.Tensor(self.args['tensor_neurons'], 2*self.args['filters_3'])) + self.bias = torch.nn.Parameter(torch.Tensor(self.args['tensor_neurons'], 1)) + + def init_parameters(self): + """ + Initializing weights. + """ + torch.nn.init.xavier_uniform_(self.weight_matrix) + torch.nn.init.xavier_uniform_(self.weight_matrix_block) + torch.nn.init.xavier_uniform_(self.bias) + + def forward(self, embedding_1, embedding_2): + """ + Making a forward propagation pass to create a similarity vector. + :param embedding_1: Result of the 1st embedding after attention. + :param embedding_2: Result of the 2nd embedding after attention. + :return scores: A similarity score vector. 
+ """ + batch_size = len(embedding_1) + scoring = torch.matmul(embedding_1, self.weight_matrix.view(self.args['filters_3'],-1)) + scoring = scoring.view(batch_size, self.args['filters_3'], -1).permute([0, 2, 1]) + scoring = torch.matmul(scoring, embedding_2.view(batch_size, self.args['filters_3'], 1)).view(batch_size, -1) + combined_representation = torch.cat((embedding_1, embedding_2), 1) + block_scoring = torch.t(torch.mm(self.weight_matrix_block, torch.t(combined_representation))) + scores = F.relu(scoring + block_scoring + self.bias.view(-1)) + return scores + + +class GCNConv(nn.Module): + + def __init__(self, in_features: int, out_features: int): + super(GCNConv, self).__init__() + self.num_inputs = in_features + self.num_outputs = out_features + self.weight = Parameter(torch.empty((in_features,out_features))) + self.bias = Parameter(torch.empty(out_features)) + + def forward(self, A: Tensor, x: Tensor, norm: bool=True) -> Tensor: + r""" + Forward computation of graph convolution network. + + :param A: :math:`(b\times n\times n)` {0,1} adjacency matrix. :math:`b`: batch size, :math:`n`: number of nodes + :param x: :math:`(b\times n\times d)` input node embedding. 
:math:`d`: feature dimension + :param norm: normalize connectivity matrix or not + :return: :math:`(b\times n\times d^\prime)` new node embedding + """ + x = torch.mm(x,self.weight) + self.bias + D = torch.zeros_like(A) + + for i in range(A.shape[0]): + A[i,i] = 1 + D[i,i] = torch.pow(torch.sum(A[i]),exponent=-0.5) + A = torch.mm(torch.mm(D,A),D) + return torch.mm(A,x) + + def __repr__(self): + return f"{self.__class__.__name__}(in_features={self.num_inputs}, out_features={self.num_outputs})" diff --git a/pygmtools/pytorch_backend.py b/pygmtools/pytorch_backend.py index 2d3c9eca..b43a52f2 100644 --- a/pygmtools/pytorch_backend.py +++ b/pygmtools/pytorch_backend.py @@ -17,6 +17,10 @@ import os import pygmtools.utils +from .pytorch_astar_modules import GCNConv, AttentionModule, TensorNetworkModule, GraphPair, \ + VERY_LARGE_INT, to_dense_adj, to_dense_batch, default_parameter, check_layer_parameter, node_metric +from torch import Tensor +from pygmtools.a_star import a_star ############################################# # Linear Assignment Problem Solvers # @@ -25,9 +29,9 @@ from pygmtools.numpy_backend import _hung_kernel -def hungarian(s: Tensor, n1: Tensor=None, n2: Tensor=None, - unmatch1: Tensor=None, unmatch2: Tensor=None, - nproc: int=1) -> Tensor: +def hungarian(s: Tensor, n1: Tensor = None, n2: Tensor = None, + unmatch1: Tensor = None, unmatch2: Tensor = None, + nproc: int = 1) -> Tensor: """ Pytorch implementation of Hungarian algorithm """ @@ -57,16 +61,17 @@ def hungarian(s: Tensor, n1: Tensor=None, n2: Tensor=None, mapresult = pool.starmap_async(_hung_kernel, zip(perm_mat, n1, n2, unmatch1, unmatch2)) perm_mat = np.stack(mapresult.get()) else: - perm_mat = np.stack([_hung_kernel(perm_mat[b], n1[b], n2[b], unmatch1[b], unmatch2[b]) for b in range(batch_num)]) + perm_mat = np.stack( + [_hung_kernel(perm_mat[b], n1[b], n2[b], unmatch1[b], unmatch2[b]) for b in range(batch_num)]) perm_mat = torch.from_numpy(perm_mat).to(device) return perm_mat -def sinkhorn(s: 
Tensor, nrows: Tensor=None, ncols: Tensor=None, - unmatchrows: Tensor=None, unmatchcols: Tensor=None, - dummy_row: bool=False, max_iter: int=10, tau: float=1., batched_operation: bool=False) -> Tensor: +def sinkhorn(s: Tensor, nrows: Tensor = None, ncols: Tensor = None, + unmatchrows: Tensor = None, unmatchcols: Tensor = None, + dummy_row: bool = False, max_iter: int = 10, tau: float = 1., batched_operation: bool = False) -> Tensor: """ Pytorch implementation of Sinkhorn algorithm """ @@ -91,7 +96,7 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None, s_t = s.transpose(1, 2) s_t = torch.cat(( s_t[:, :s.shape[1], :], - torch.full((batch_size, s.shape[1], s.shape[2]-s.shape[1]), -float('inf'), device=s.device)), dim=2) + torch.full((batch_size, s.shape[1], s.shape[2] - s.shape[1]), -float('inf'), device=s.device)), dim=2) s = torch.where(transposed_batch.view(batch_size, 1, 1), s_t, s) new_nrows = torch.where(transposed_batch, ncols, nrows) @@ -103,8 +108,9 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None, unmatchrows_pad = torch.cat(( unmatchrows, torch.full((batch_size, unmatchcols.shape[1] - unmatchrows.shape[1]), -float('inf'), device=s.device)), - dim=1) - new_unmatchrows = torch.where(transposed_batch.view(batch_size, 1), unmatchcols, unmatchrows_pad)[:, :unmatchrows.shape[1]] + dim=1) + new_unmatchrows = torch.where(transposed_batch.view(batch_size, 1), unmatchcols, unmatchrows_pad)[:, + :unmatchrows.shape[1]] new_unmatchcols = torch.where(transposed_batch.view(batch_size, 1), unmatchrows_pad, unmatchcols) unmatchrows = new_unmatchrows unmatchcols = new_unmatchcols @@ -121,15 +127,19 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None, dummy_shape[1] = log_s.shape[2] - log_s.shape[1] ori_nrows = nrows nrows = ncols.clone() - log_s = torch.cat((log_s, torch.full(dummy_shape, -float('inf'), device=log_s.device, dtype=log_s.dtype)), dim=1) + log_s = torch.cat((log_s, torch.full(dummy_shape, -float('inf'), 
device=log_s.device, dtype=log_s.dtype)), + dim=1) if unmatchrows is not None: - unmatchrows = torch.cat((unmatchrows, torch.full((dummy_shape[0], dummy_shape[1]), -float('inf'), device=log_s.device, dtype=log_s.dtype)), dim=1) + unmatchrows = torch.cat((unmatchrows, + torch.full((dummy_shape[0], dummy_shape[1]), -float('inf'), device=log_s.device, + dtype=log_s.dtype)), dim=1) for b in range(batch_size): log_s[b, ori_nrows[b]:nrows[b], :ncols[b]] = -100 # assign the unmatch weights if unmatchrows is not None and unmatchcols is not None: - new_log_s = torch.full((log_s.shape[0], log_s.shape[1]+1, log_s.shape[2]+1), -float('inf'), device=log_s.device, dtype=log_s.dtype) + new_log_s = torch.full((log_s.shape[0], log_s.shape[1] + 1, log_s.shape[2] + 1), -float('inf'), + device=log_s.device, dtype=log_s.dtype) new_log_s[:, :-1, :-1] = log_s log_s = new_log_s for b in range(batch_size): @@ -161,7 +171,8 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None, ret_log_s = log_s else: - ret_log_s = torch.full((batch_size, log_s.shape[1], log_s.shape[2]), -float('inf'), device=log_s.device, dtype=log_s.dtype) + ret_log_s = torch.full((batch_size, log_s.shape[1], log_s.shape[2]), -float('inf'), device=log_s.device, + dtype=log_s.dtype) for b in range(batch_size): row_slice = slice(0, nrows[b]) @@ -198,7 +209,8 @@ def sinkhorn(s: Tensor, nrows: Tensor=None, ncols: Tensor=None, s_t = ret_log_s.transpose(1, 2) s_t = torch.cat(( s_t[:, :ret_log_s.shape[1], :], - torch.full((batch_size, ret_log_s.shape[1], ret_log_s.shape[2]-ret_log_s.shape[1]), -float('inf'), device=log_s.device)), dim=2) + torch.full((batch_size, ret_log_s.shape[1], ret_log_s.shape[2] - ret_log_s.shape[1]), -float('inf'), + device=log_s.device)), dim=2) ret_log_s = torch.where(transposed_batch.view(batch_size, 1, 1), s_t, ret_log_s) if transposed: @@ -410,7 +422,7 @@ def _comp_aff_score(x, k): X1 = X.reshape(m, 1, m, n, n).repeat(1, m, 1, 1, 1).reshape(-1, n, n) # X1[i,j,k] = X[i,k] X2 = X.reshape(1, 
m, m, n, n).repeat(m, 1, 1, 1, 1).transpose(1, 2).reshape(-1, n, n) # X2[i,j,k] = X[k,j] - X_combo = torch.bmm(X1, X2).reshape(m, m, m, n, n) # X_combo[i,j,k] = X[i, k] * X[k, j] + X_combo = torch.bmm(X1, X2).reshape(m, m, m, n, n) # X_combo[i,j,k] = X[i, k] * X[k, j] aff_ori = (_comp_aff_score(X.reshape(-1, n, n), K.reshape(-1, n * n, n * n)) / norm).reshape(m, m) pair_con = _get_batch_pc_opt(X) @@ -433,7 +445,8 @@ def _comp_aff_score(x, k): assert torch.all(score_combo + 1e-4 >= score_ori), torch.min(score_combo - score_ori) X_upt = X_combo[mask1, mask2, idx, :, :] - X = X_upt * X_mask + X_upt.transpose(0, 1).transpose(2, 3) * X_mask.transpose(0, 1) + X * (1 - X_mask - X_mask.transpose(0, 1)) + X = X_upt * X_mask + X_upt.transpose(0, 1).transpose(2, 3) * X_mask.transpose(0, 1) + X * ( + 1 - X_mask - X_mask.transpose(0, 1)) assert torch.all(X.transpose(0, 1).transpose(2, 3) == X) return X @@ -597,7 +610,7 @@ def gamgm( sk_iter, max_iter, quad_weight, converge_thresh, outlier_thresh, bb_smooth, verbose, - cluster_M=None, projector='sinkhorn', hung_iter=True # these arguments are reserved for clustering + cluster_M=None, projector='sinkhorn', hung_iter=True # these arguments are reserved for clustering ): """ Pytorch implementation of Graduated Assignment for Multi-Graph Matching (with compatibility for 2GM and clustering) @@ -661,6 +674,7 @@ class GAMGMTorchFunc(torch.autograd.Function): """ Torch wrapper to support forward and backward pass (by black-box differentiation) """ + @staticmethod def forward(ctx, bb_smooth, supA, supW, ns, n_indices, n_univ, num_graphs, U0, *args): # save parameters @@ -688,7 +702,8 @@ def backward(ctx, dU): end_x = n_indices[i] start_y = n_indices[j] - ns[j] end_y = n_indices[j] - supW[start_x:end_x, start_y:end_y] += bb_smooth * torch.mm(dU[start_x:end_x], dU[start_y:end_y].transpose(0, 1)) + supW[start_x:end_x, start_y:end_y] += bb_smooth * torch.mm(dU[start_x:end_x], + dU[start_y:end_y].transpose(0, 1)) U_prime = gamgm_real(supA, 
supW, ns, n_indices, n_univ, num_graphs, U0, *args) @@ -712,8 +727,8 @@ def gamgm_real( sk_iter, max_iter, quad_weight, converge_thresh, outlier_thresh, verbose, - cluster_M, projector, hung_iter # these arguments are reserved for clustering - ): + cluster_M, projector, hung_iter # these arguments are reserved for clustering +): """ The real forward function of GAMGM """ @@ -736,7 +751,7 @@ def gamgm_real( else: print_str = 'hungarian' print(print_str + f' #iter={i}/{max_iter} ' - f'quad score: {(quad * U).sum():.3e}, unary score: {(unary * U).sum():.3e}') + f'quad score: {(quad * U).sum():.3e}, unary score: {(unary * U).sum():.3e}') V = (quad + unary) / num_graphs U_list = [] @@ -813,7 +828,7 @@ def gamgm_real( if verbose: print('-' * 20) - if i == max_iter - 1: # not converged + if i == max_iter - 1: # not converged if hung_iter: pass else: @@ -837,6 +852,329 @@ def gamgm_real( return U +astar_pretrain_path = { + 'AIDS700nef': ('https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/pytorch_backend/best_genn_AIDS700nef_gcn_astar.pt', + 'b2516aea4c8d730704a48653a5ca94ba'), + 'LINUX': ('https://raw.githubusercontent.com/heatingma/pygmtools-pretrained-models/main/pytorch_backend/best_genn_LINUX_gcn_astar.pt', + 'fd3b2a8dfa3edb20607da2e2b96d2e96'), +} + + +class GENN(torch.nn.Module): + def __init__(self, args): + """ + :param args: Arguments object. + :param number_of_labels: Number of node labels. + """ + super(GENN, self).__init__() + self.args = args + if self.args['use_net']: + self.number_labels = self.args['channel'] + self.setup_layers() + + self.reset_cache() + + def reset_cache(self): + self.gnn_1_cache = dict() + self.gnn_2_cache = dict() + self.heuristic_cache = dict() + + def setup_layers(self): + """ + Creating the layers. 
+ """ + self.feature_count = self.args['tensor_neurons'] + self.convolution_1 = GCNConv(self.number_labels, self.args['filters_1']) + self.convolution_2 = GCNConv(self.args['filters_1'], self.args['filters_2']) + self.convolution_3 = GCNConv(self.args['filters_2'], self.args['filters_3']) + self.attention = AttentionModule(self.args) + self.tensor_network = TensorNetworkModule(self.args) + self.scoring_layer = torch.nn.Sequential( + torch.nn.Linear(self.feature_count, 16), + torch.nn.ReLU(), + torch.nn.Linear(16, 1), + torch.nn.Sigmoid() + ) + + def convolutional_pass(self, edge_index, x, edge_weight=None): + """ + Making convolutional pass. + :param edge_index: Edge indices. + :param x: Feature matrix. + :param edge_weight: Edge weights. + :return features: Abstract feature matrix. + """ + + features = self.convolution_1(edge_index, x, edge_weight) + features = F.relu(features) + features = F.dropout(features, p=self.args['dropout'], training=self.training) + features = self.convolution_2(edge_index, features, edge_weight) + features = F.relu(features) + features = F.dropout(features, p=self.args['dropout'], training=self.training) + features = self.convolution_3(edge_index, features, edge_weight) + return features + + def forward(self, data: GraphPair): + """ + Forward pass with graphs. + :param data: Data dictionary. + :return score: Similarity score. 
+ """ + num = data.g1.num_graphs + max_nodes_num_1 = torch.max(data.g1.nodes_num) + 1 + max_nodes_num_2 = torch.max(data.g2.nodes_num) + 1 + x_pred = torch.zeros(num, max_nodes_num_1, max_nodes_num_2) + for i in range(num): + cur_data = GraphPair(data.g1.x[i], data.g2.x[i], data.g1.adj[i], data.g2.adj[i], + data.g1.nodes_num[i], data.g2.nodes_num[i]) + num_nodes_1 = data.g1.nodes_num[i] + 1 + num_nodes_2 = data.g2.nodes_num[i] + 1 + x_pred[i][:num_nodes_1, :num_nodes_2] = self._a_star(cur_data) + return x_pred[:, :-1, :-1] + + def _a_star(self, data: GraphPair): + + if self.args['cuda']: + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = "cpu" + edge_index_1 = data.g1.edge_index.squeeze() + edge_index_2 = data.g2.edge_index.squeeze() + edge_attr_1 = data.g1.edge_weight + edge_attr_2 = data.g2.edge_weight + node_1 = data.g1.x.squeeze() + node_2 = data.g2.x.squeeze() + batch_1 = data.g1.batch + batch_2 = data.g2.batch + batch_num = data.g1.num_graphs + + ns_1 = torch.bincount(data.g1.batch) + ns_2 = torch.bincount(data.g2.batch) + + adj_1 = to_dense_adj(edge_index_1, batch=batch_1, edge_attr=edge_attr_1) + + dummy_adj_1 = torch.zeros(adj_1.shape[0], adj_1.shape[1] + 1, adj_1.shape[2] + 1, device=device) + dummy_adj_1[:, :-1, :-1] = adj_1 + adj_2 = to_dense_adj(edge_index_2, batch=batch_2, edge_attr=edge_attr_2) + dummy_adj_2 = torch.zeros(adj_2.shape[0], adj_2.shape[1] + 1, adj_2.shape[2] + 1, device=device) + dummy_adj_2[:, :-1, :-1] = adj_2 + + node_1, _ = to_dense_batch(node_1, batch=batch_1) + node_2, _ = to_dense_batch(node_2, batch=batch_2) + + dummy_node_1 = torch.zeros(adj_1.shape[0], node_1.shape[1] + 1, node_1.shape[-1], device=device) + dummy_node_1[:, :-1, :] = node_1 + dummy_node_2 = torch.zeros(adj_2.shape[0], node_2.shape[1] + 1, node_2.shape[-1], device=device) + dummy_node_2[:, :-1, :] = node_2 + k_diag = node_metric(dummy_node_1, dummy_node_2) + + mask_1 = torch.zeros_like(dummy_adj_1) + mask_2 = 
torch.zeros_like(dummy_adj_2) + for b in range(batch_num): + mask_1[b, :ns_1[b] + 1, :ns_1[b] + 1] = 1 + mask_1[b, :ns_1[b], :ns_1[b]] -= torch.eye(ns_1[b], device=mask_1.device) + mask_2[b, :ns_2[b] + 1, :ns_2[b] + 1] = 1 + mask_2[b, :ns_2[b], :ns_2[b]] -= torch.eye(ns_2[b], device=mask_2.device) + + a1 = dummy_adj_1.reshape(batch_num, -1, 1) + a2 = dummy_adj_2.reshape(batch_num, 1, -1) + m1 = mask_1.reshape(batch_num, -1, 1) + m2 = mask_2.reshape(batch_num, 1, -1) + k = torch.abs(a1 - a2) * torch.bmm(m1, m2) + k[torch.logical_not(torch.bmm(m1, m2).to(dtype=torch.bool))] = VERY_LARGE_INT + k = k.reshape(batch_num, dummy_adj_1.shape[1], dummy_adj_1.shape[2], dummy_adj_2.shape[1], dummy_adj_2.shape[2]) + k = k.permute([0, 1, 3, 2, 4]) + k = k.reshape(batch_num, dummy_adj_1.shape[1] * dummy_adj_2.shape[1], + dummy_adj_1.shape[2] * dummy_adj_2.shape[2]) + k = k / 2 + + for b in range(batch_num): + k_diag_view = torch.diagonal(k[b]) + k_diag_view[:] = k_diag[b].reshape(-1) + + self.reset_cache() + + x_pred, _ = a_star( + data, k, ns_1.cpu().numpy(), ns_2.cpu().numpy(), + self.net_prediction_cache, + self.heuristic_prediction_hun, + net_pred=self.args['use_net'], + beam_width=self.args['astar_beam_width'], + trust_fact=self.args['astar_trust_fact'], + no_pred_size=self.args['astar_no_pred'], + ) + + return x_pred + + def net_prediction_cache(self, data: GraphPair, partial_pmat=None, return_ged_norm=False): + """ + Forward pass with graphs. + :param data: Data class. + :param partial_pmat: Matched matrix. + :param return_ged_norm: Whether to return to Normal Graph Edit Distance. + :return score: Similarity score. 
+ """ + features_1 = data.g1.x.squeeze() + features_2 = data.g2.x.squeeze() + batch_1 = data.g1.batch + batch_2 = data.g2.batch + adj1 = data.g1.adj.squeeze() + adj2 = data.g2.adj.squeeze() + + if 'gnn_feat' not in self.gnn_1_cache: + abstract_features_1 = self.convolutional_pass(adj1, features_1) + self.gnn_1_cache['gnn_feat'] = abstract_features_1 + else: + abstract_features_1 = self.gnn_1_cache['gnn_feat'] + if 'gnn_feat' not in self.gnn_2_cache: + abstract_features_2 = self.convolutional_pass(adj2, features_2) + self.gnn_2_cache['gnn_feat'] = abstract_features_2 + else: + abstract_features_2 = self.gnn_2_cache['gnn_feat'] + + graph_1_mask = torch.ones_like(batch_1) + graph_2_mask = torch.ones_like(batch_2) + graph_1_matched = partial_pmat.sum(dim=-1).to(dtype=torch.bool)[:graph_1_mask.shape[0]] + graph_2_matched = partial_pmat.sum(dim=-2).to(dtype=torch.bool)[:graph_2_mask.shape[0]] + graph_1_mask = torch.logical_not(graph_1_matched) + graph_2_mask = torch.logical_not(graph_2_matched) + abstract_features_1 = abstract_features_1[graph_1_mask] + abstract_features_2 = abstract_features_2[graph_2_mask] + batch_1 = batch_1[graph_1_mask] + batch_2 = batch_2[graph_2_mask] + pooled_features_1 = self.attention(abstract_features_1, batch_1) + pooled_features_2 = self.attention(abstract_features_2, batch_2) + scores = self.tensor_network(pooled_features_1, pooled_features_2) + score = self.scoring_layer(scores).view(-1) + + if return_ged_norm: + return score + else: + ged = - torch.log(score) * (batch_1.shape[0] + batch_2.shape[0]) / 2 + return ged + + def heuristic_prediction_hun(self, k: torch.Tensor, n1, n2, partial_pmat): + k_prime = k.reshape(-1, n1 + 1, n2 + 1) + node_costs = torch.empty(k_prime.shape[0]) + for i in range(k_prime.shape[0]): + _, node_costs[i] = hungarian_ged(k_prime[i], n1, n2) + node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1) + self.heuristic_cache['node_cost'] = node_cost_mat + + graph_1_mask = ~partial_pmat.sum(dim=-1).to(dtype=torch.bool) + 
graph_2_mask = ~partial_pmat.sum(dim=-2).to(dtype=torch.bool) + graph_1_mask[-1] = 1 + graph_2_mask[-1] = 1 + node_cost_mat = node_cost_mat[graph_1_mask, :] + node_cost_mat = node_cost_mat[:, graph_2_mask] + + _, ged = hungarian_ged(node_cost_mat, torch.sum(graph_1_mask[:-1]), torch.sum(graph_2_mask[:-1])) + + return ged + + +def hungarian_ged(node_cost_mat: torch.Tensor, n1, n2): + assert node_cost_mat.shape[-2] == n1 + 1 + assert node_cost_mat.shape[-1] == n2 + 1 + device = node_cost_mat.device + upper_left = node_cost_mat[:n1, :n2] + upper_right = torch.full((n1, n1), float('inf'), device=device) + torch.diagonal(upper_right)[:] = node_cost_mat[:-1, -1] + lower_left = torch.full((n2, n2), float('inf'), device=device) + torch.diagonal(lower_left)[:] = node_cost_mat[-1, :-1] + lower_right = torch.zeros((n2, n1), device=device) + large_cost_mat = torch.cat((torch.cat((upper_left, upper_right), dim=1), + torch.cat((lower_left, lower_right), dim=1)), dim=0) + + large_pred_x = hungarian(-large_cost_mat.unsqueeze(dim=0)).squeeze() + pred_x = torch.zeros_like(node_cost_mat) + pred_x[:n1, :n2] = large_pred_x[:n1, :n2] + pred_x[:-1, -1] = torch.sum(large_pred_x[:n1, n2:], dim=1) + pred_x[-1, :-1] = torch.sum(large_pred_x[n1:, :n2], dim=0) + + ged_lower_bound = torch.sum(pred_x * node_cost_mat) + return pred_x, ged_lower_bound + + +def astar(feat1, feat2, A1, A2, n1, n2, channel, dropout, beam_width, trust_fact, no_pred_size): + """ + Pytorch implementation of ASTAR + """ + return astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, dropout=dropout, beam_width=beam_width, + filters_1=64, filters_2=32, filters_3=16, tensor_neurons=16, trust_fact=trust_fact, + no_pred_size=no_pred_size, pretrain=False, network=None, use_net=False) + + +def astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3, + tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain, use_net): + """ + The true implementation of astar and genn_astar 
functions + """ + if feat1 is None: + forward_pass = False + device = torch.device('cpu') + else: + assert feat1.shape[-1] == feat2.shape[-1], 'The feature dimensions of feat1 and feat2 must be consistent' + forward_pass = True + device = feat1.device + + if network is None: + args = default_parameter() + if forward_pass: + if channel is None: + args['channel'] = feat1.shape[-1] + else: + assert feat1.shape[-1] == channel, 'the channel {} must match the feature dimension of feat1\n'.format( + channel) + args['channel'] = channel + else: + if channel is None: + args['channel'] = 8 if pretrain == "LINUX" else 36 + else: + args['channel'] = channel + + args['filters_1'] = filters_1 + args['filters_2'] = filters_2 + args['filters_3'] = filters_3 + args['tensor_neurons'] = tensor_neurons + args['dropout'] = dropout + args['astar_beam_width'] = beam_width + args['astar_trust_fact'] = trust_fact + args['astar_no_pred'] = no_pred_size + args['pretrain'] = pretrain + args['use_net'] = use_net + + network = GENN(args) + + network = network.to(device) + if pretrain and args['use_net']: + if pretrain in astar_pretrain_path: + url, md5 = astar_pretrain_path[pretrain] + filename = pygmtools.utils.download(f'best_genn_{pretrain}_gcn_astar.pt', url, md5) + if check_layer_parameter(args): + _load_model(network, filename, device) + else: + _load_part_model(network, filename, device) + message = 'Warning: Pretrain {} does not support the parameters you entered, '.format(pretrain) + if args['pretrain'] == 'AIDS700nef': + message += "Supported parameters: ( channel:36, filters:(64,32,16) ), " + elif args['pretrain'] == 'LINUX': + message += "Supported parameters: ( channel:8, filters:(64,32,16) ), " + message += 'Input parameters: ( channel:{}, filters:({},{},{}) )'.format(args['channel'], \ + args['filters_1'], args['filters_2'], args['filters_3']) + print(message) + else: + raise ValueError(f'Unknown pretrain tag. 
Available tags: {astar_pretrain_path.keys()}') + + if forward_pass: + assert A1.shape[0] == A2.shape[0] + data = GraphPair(feat1, feat2, A1, A2, n1, n2) + result = network(data) + else: + result = None + return result, network + + ############################################ # Neural Network Solvers # ############################################ @@ -848,6 +1186,7 @@ class PCA_GM_Net(torch.nn.Module): """ Pytorch implementation of PCA-GM and IPCA-GM network """ + def __init__(self, in_channel, hidden_channel, out_channel, num_layers, cross_iter_num=-1): super(PCA_GM_Net, self).__init__() self.gnn_layer = num_layers @@ -863,8 +1202,7 @@ def __init__(self, in_channel, hidden_channel, out_channel, num_layers, cross_it if i == self.gnn_layer - 2: # only the second last layer will have cross-graph module self.add_module('cross_graph_{}'.format(i), torch.nn.Linear(hidden_channel * 2, hidden_channel)) if cross_iter_num <= 0: - self.add_module('affinity_{}'.format(i), WeightedInnerProdAffinity(hidden_channel)) - + self.add_module('affinity_{}'.format(i), WeightedInnerProdAffinity(hidden_channel)) def forward(self, feat1, feat2, A1, A2, n1, n2, cross_iter_num, sk_max_iter, sk_tau): _sinkhorn_func = functools.partial(sinkhorn, @@ -972,8 +1310,8 @@ def pca_gm(feat1, feat2, A1, A2, n1, n2, def ipca_gm(feat1, feat2, A1, A2, n1, n2, - in_channel, hidden_channel, out_channel, num_layers, cross_iter, sk_max_iter, sk_tau, - network, pretrain): + in_channel, hidden_channel, out_channel, num_layers, cross_iter, sk_max_iter, sk_tau, + network, pretrain): """ Pytorch implementation of IPCA-GM """ @@ -1009,6 +1347,7 @@ class CIE_Net(torch.nn.Module): """ Pytorch implementation of CIE graph matching network """ + def __init__(self, in_node_channel, in_edge_channel, hidden_channel, out_channel, num_layers): super(CIE_Net, self).__init__() self.gnn_layer = num_layers @@ -1099,6 +1438,7 @@ class NGM_Net(torch.nn.Module): """ Pytorch implementation of NGM network """ + def __init__(self, 
gnn_channels, sk_emb): super(NGM_Net, self).__init__() self.gnn_layer = len(gnn_channels) @@ -1173,6 +1513,15 @@ def ngm(K, n1, n2, n1max, n2max, x0, gnn_channels, sk_emb, sk_max_iter, sk_tau, return result, network +def genn_astar(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3, + tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain): + """ + Pytorch implementation of GENN-ASTAR + """ + return astar_kernel(feat1, feat2, A1, A2, n1, n2, channel, filters_1, filters_2, filters_3, + tensor_neurons, dropout, beam_width, trust_fact, no_pred_size, network, pretrain, use_net=True) + + ############################################# # Utils Functions # ############################################# @@ -1222,7 +1571,8 @@ def build_batch(input, return_ori_dim=False): padded_ts.append(torch.nn.functional.pad(t, pad_pattern, 'constant', 0)) if return_ori_dim: - return torch.stack(padded_ts, dim=0), tuple([torch.tensor(_, dtype=torch.int64, device=device) for _ in ori_shape]) + return torch.stack(padded_ts, dim=0), tuple( + [torch.tensor(_, dtype=torch.int64, device=device) for _ in ori_shape]) else: return torch.stack(padded_ts, dim=0) @@ -1361,8 +1711,10 @@ def _aff_mat_from_node_edge_aff(node_aff: Tensor, edge_aff: Tensor, connectivity if edge_aff is not None: conn1 = connectivity1[b][:ne1[b]] conn2 = connectivity2[b][:ne2[b]] - edge_indices = torch.cat([conn1.repeat_interleave(ne2[b], dim=0), conn2.repeat(ne1[b], 1)], dim=1) # indices: start_g1, end_g1, start_g2, end_g2 - edge_indices = (edge_indices[:, 2], edge_indices[:, 0], edge_indices[:, 3], edge_indices[:, 1]) # indices: start_g2, start_g1, end_g2, end_g1 + edge_indices = torch.cat([conn1.repeat_interleave(ne2[b], dim=0), conn2.repeat(ne1[b], 1)], + dim=1) # indices: start_g1, end_g1, start_g2, end_g2 + edge_indices = (edge_indices[:, 2], edge_indices[:, 0], edge_indices[:, 3], + edge_indices[:, 1]) # indices: start_g2, start_g1, end_g2, end_g1 k[edge_indices] = 
edge_aff[b, :ne1[b], :ne2[b]].reshape(-1) k = k.reshape(n2max * n1max, n2max * n1max) # node-wise affinity @@ -1451,3 +1803,25 @@ def _load_model(model, path, device, strict=True): if len(missing_keys) > 0: print('Warning: Missing key(s) in state_dict: {}. '.format( ', '.join('"{}"'.format(k) for k in missing_keys))) + + +def _load_part_model(model, path, device): + """ + Load PyTorch model from a given path. This function is used for some parameters' size mismatching + """ + if isinstance(model, torch.nn.DataParallel): + module = model.module + else: + module = model + model_state_dict = module.state_dict() + load_state_dict = torch.load(path, map_location=device) + skip_keys= list() + for name, param in load_state_dict.items(): + if name in model_state_dict: + try: + model_state_dict[name].copy_(param) + except: + skip_keys.append(name) + if len(skip_keys) > 0: + print('Warning: Skip key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in skip_keys))) diff --git a/pygmtools/utils.py b/pygmtools/utils.py index f2c75526..e496cca5 100644 --- a/pygmtools/utils.py +++ b/pygmtools/utils.py @@ -26,6 +26,8 @@ import wget import numpy as np import pygmtools +import networkx as nx +import urllib.request NOT_IMPLEMENTED_MSG = \ 'The backend function for {} is not implemented. ' \ @@ -1200,6 +1202,7 @@ def _mm(input1, input2, backend=None): ) return fn(*args) + def download(filename, url, md5=None, retries=10, to_cache=True): r""" Check if content exits. If not, download the content to ``/pygmtools/``. ```` @@ -1207,7 +1210,7 @@ def download(filename, url, md5=None, retries=10, to_cache=True): :param filename: the destination file name :param url: the url :param md5: (optional) the md5sum to verify the content. It should match the result of ``md5sum file`` on Linux. 
- :param retries: (default: 5) max number of retries + :param retries: (default: 10) max number of retries :return: the full path to the file: ``/pygmtools/`` """ if retries <= 0: @@ -1220,7 +1223,7 @@ def download(filename, url, md5=None, retries=10, to_cache=True): filename = os.path.join(dirs, filename) if not os.path.exists(filename): print(f'\nDownloading to {filename}...') - if retries % 2 == 1: + if retries % 3 == 1: try: down_res = requests.get(url, stream=True) file_size = int(down_res.headers.get('Content-Length', 0)) @@ -1230,11 +1233,17 @@ def download(filename, url, md5=None, retries=10, to_cache=True): except requests.exceptions.ConnectionError as err: print('Warning: Network error. Retrying...\n', err) return download(filename, url, md5, retries - 1) - else: + elif retries % 3 == 2: try: wget.download(url,out=filename) except: return download(filename, url, md5, retries - 1) + else: + try: + urllib.request.urlretrieve(url, filename) + except: + return download(filename, url, md5, retries - 1) + if md5 is not None: md5_returned = _get_md5(filename) if md5 != md5_returned: @@ -1244,6 +1253,7 @@ def download(filename, url, md5=None, retries=10, to_cache=True): return download(filename, url, md5, retries - 1) return filename + def _get_md5(filename): hash_md5 = hashlib.md5() chunk = 8192 @@ -1255,3 +1265,254 @@ def _get_md5(filename): hash_md5.update(buffer) md5_returned = hash_md5.hexdigest() return md5_returned + + +################################################### +# Support NetworkX and GraphML formats # +################################################### + + +def build_aff_mat_from_networkx(G1:nx.Graph, G2:nx.Graph, node_aff_fn=None, edge_aff_fn=None, backend=None): + r""" + Convert networkx object to Adjacency matrix + + :param G1: networkx object, whose type must be networkx.Graph + :param G2: networkx object, whose type must be networkx.Graph + :param node_aff_fn: (default: inner_prod_aff_fn) the node affinity function with the characteristic 
+ ``node_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two node feature tensors and + outputs the node-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an + example. + :param edge_aff_fn: (default: inner_prod_aff_fn) the edge affinity function with the characteristic + ``edge_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two edge feature tensors and + outputs the edge-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an + example. + :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation. + :return: the affinity matrix corresponding to the networkx object G1 and G2 + + .. dropdown:: Example + + :: + + >>> import networkx as nx + >>> import pygmtools as pygm + >>> pygm.BACKEND = 'numpy' + + # Generate networkx images + >>> G1 = nx.DiGraph() + >>> G1.add_weighted_edges_from([(1, 2, 0.5), (2, 3, 0.8), (3, 4, 0.7)]) + >>> G2 = nx.DiGraph() + >>> G2.add_weighted_edges_from([(1, 2, 0.3), (2, 3, 0.6), (3, 4, 0.9), (4, 5, 0.4)]) + + # Obtain Affinity Matrix + >>> K = pygm.utils.build_aff_mat_from_networkx(G1, G2) + >>> K.shape + (20,20) + + # The affinity matrices K can be further processed by GM solvers + """ + if backend is None: + backend = pygmtools.BACKEND + A1 = from_numpy(np.asarray(from_networkx(G1))) + A2 = from_numpy(np.asarray(from_networkx(G2))) + conn1, edge1 = dense_to_sparse(A1, backend=backend) + conn2, edge2 = dense_to_sparse(A2, backend=backend) + K = build_aff_mat(None, edge1, conn1, None, edge2, conn2, node_aff_fn=node_aff_fn, edge_aff_fn=edge_aff_fn, backend=backend) + return K + + +def build_aff_mat_from_graphml(G1_path, G2_path, node_aff_fn=None, edge_aff_fn=None, backend=None): + r""" + Convert networkx object to Adjacency matrix + + :param G1_path: The file path of the graphml object + :param G2_path: The file path of the graphml object + :param node_aff_fn: (default: inner_prod_aff_fn) the node affinity function with the characteristic + 
``node_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two node feature tensors and + outputs the node-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an + example. + :param edge_aff_fn: (default: inner_prod_aff_fn) the edge affinity function with the characteristic + ``edge_aff_fn(2D Tensor, 2D Tensor) -> 2D Tensor``, which accepts two edge feature tensors and + outputs the edge-wise affinity tensor. See :func:`~pygmtools.utils.inner_prod_aff_fn` as an + example. + :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation. + :return: the affinity matrix corresponding to the graphml object G1 and G2 + + + .. dropdown:: Example + + :: + + >>> import pygmtools as pygm + >>> pygm.BACKEND = 'numpy' + + # example file (.graphml) path + >>> G1_path = 'examples/data/graph1.graphml' + >>> G2_path = 'examples/data/graph2.graphml' + + # Obtain Affinity Matrix + >>> K = pygm.utils.build_aff_mat_from_graphml(G1_path, G2_path) + >>> K.shape + (121,121) + + # The affinity matrices K can be further processed by GM solvers + """ + if backend is None: + backend = pygmtools.BACKEND + A1 = from_numpy(np.asarray(from_graphml(G1_path))) + A2 = from_numpy(np.asarray(from_graphml(G2_path))) + conn1, edge1 = dense_to_sparse(A1, backend=backend) + conn2, edge2 = dense_to_sparse(A2, backend=backend) + K = build_aff_mat(None, edge1, conn1, None, edge2, conn2, node_aff_fn=node_aff_fn, edge_aff_fn=edge_aff_fn, backend=backend) + return K + + +def from_networkx(G:nx.Graph): + r""" + Convert networkx object to Adjacency matrix + + :param G: networkx object, whose type must be networkx.Graph + :return: the adjacency matrix corresponding to the networkx object + + .. 
dropdown:: Example + + :: + + >>> import networkx as nx + >>> import pygmtools as pygm + >>> pygm.BACKEND = 'numpy' + + # Generate networkx graphs + >>> G1 = nx.DiGraph() + >>> G1.add_weighted_edges_from([(1, 2, 0.5), (2, 3, 0.8), (3, 4, 0.7)]) + >>> G2 = nx.DiGraph() + >>> G2.add_weighted_edges_from([(1, 2, 0.3), (2, 3, 0.6), (3, 4, 0.9), (4, 5, 0.4)]) + + # Obtain Adjacency matrix + >>> pygm.utils.from_networkx(G1) + matrix([[0. , 0.5, 0. , 0. ], + [0. , 0. , 0.8, 0. ], + [0. , 0. , 0. , 0.7], + [0. , 0. , 0. , 0. ]]) + + >>> pygm.utils.from_networkx(G2) + matrix([[0. , 0.3, 0. , 0. , 0. ], + [0. , 0. , 0.6, 0. , 0. ], + [0. , 0. , 0. , 0.9, 0. ], + [0. , 0. , 0. , 0. , 0.4], + [0. , 0. , 0. , 0. , 0. ]]) + + """ + is_directed = isinstance(G, nx.DiGraph) + adj_matrix = nx.to_numpy_matrix(G,nodelist=G.nodes()) if is_directed else nx.to_numpy_matrix(G) + return adj_matrix + + +def to_networkx(adj_matrix, backend=None): + """ + Convert adjacency matrix to NetworkX object + + :param adj_matrix: the adjacency matrix to convert + :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation. + :return: the NetworkX object corresponding to the adjacency matrix + + .. 
dropdown:: Example + + :: + + >>> import networkx as nx + >>> import pygmtools as pygm + >>> pygm.BACKEND = 'numpy' + + # Generate adjacency matrix + >>> adj_matrix = np.random.random(size=(4,4)) + + # Obtain NetworkX object + >>> pygm.utils.to_networkx(adj_matrix) + + """ + if backend is None: + backend = pygmtools.BACKEND + adj_matrix = to_numpy(adj_matrix, backend=backend) + + if adj_matrix.ndim == 3 and adj_matrix.shape[0] == 1: + adj_matrix.squeeze(0) + assert adj_matrix.ndim == 2, 'Request the dimension of adj_matrix is 2' + + G = nx.DiGraph() if np.any(adj_matrix != adj_matrix.T) else nx.Graph() + G.add_nodes_from(range(adj_matrix.shape[0])) + for i, j in zip(*np.where(adj_matrix)): + G.add_edge(i, j, weight=adj_matrix[i, j]) + return G + + +def from_graphml(filename): + r""" + Convert graphml object to Adjacency matrix + + :param filename: graphml file path + :return: the adjacency matrix corresponding to the graphml object + + .. dropdown:: Example + + :: + + >>> import pygmtools as pygm + >>> pygm.BACKEND = 'numpy' + + # example file (.graphml) path + >>> G1_path = 'examples/data/graph1.graphml' + >>> G2_path = 'examples/data/graph2.graphml' + + # Obtain Adjacency matrix + >>> G1 = pygm.utils.from_graphml(G1_path) + >>> G1.shape + (11,11) + + >>> G1 = pygm.utils.from_graphml(G2_path) + >>> G2.shape + (11,11) + """ + if not filename.endswith('.graphml'): + raise ValueError("File name should end with '.graphml'") + if not os.path.isfile(filename): + raise ValueError("File not found: {}".format(filename)) + return from_networkx(nx.read_graphml(filename)) + + +def to_graphml(adj_matrix, filename, backend=None): + r""" + Write an adjacency matrix to a GraphML file + + :param adj_matrix: numpy.ndarray, the adjacency matrix to write + :param filename: str, the name of the output file + :param backend: (default: ``pygmtools.BACKEND`` variable) the backend for computation. + + .. 
dropdown:: Example + + :: + + >>> import pygmtools as pygm + >>> import numpy as np + >>> pygm.BACKEND = 'numpy' + + # Generate adjacency matrix + >>> adj_matrix = np.random.random(size=(4,4)) + >>> filename = 'examples/data/test.graphml' + >>> adj_matrix + array([[0.29440151, 0.66468829, 0.05403941, 0.85887567], + [0.48120964, 0.01429095, 0.73536659, 0.02962113], + [0.3815578 , 0.93356234, 0.01332568, 0.61149257], + [0.15422904, 0.64656912, 0.93219422, 0.784769 ]]) + + # Write GraphML file + >>> pygm.utils.to_graphml(adj_matrix, filename) + + # Check the generated GraphML file + >>> pygm.utils.from_graphml(filename) + array([[0.29440151, 0.66468829, 0.05403941, 0.85887567], + [0.48120964, 0.01429095, 0.73536659, 0.02962113], + [0.3815578 , 0.93356234, 0.01332568, 0.61149257], + [0.15422904, 0.64656912, 0.93219422, 0.784769 ]]) + """ + nx.write_graphml(to_networkx(adj_matrix, backend), filename) + \ No newline at end of file diff --git a/setup.py b/setup.py index 5f1f3385..4a11a3af 100644 --- a/setup.py +++ b/setup.py @@ -9,14 +9,40 @@ import sys from shutil import rmtree import re - +import platform from setuptools import find_packages, setup, Command +import tarfile +import distro +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +import shutil + def get_property(prop, project): result = re.search(r'{}\s*=\s*[\'"]([^\'"]*)[\'"]'.format(prop), open(project + '/__init__.py').read()) return result.group(1) +def get_os_and_python_version(): + system = platform.system() + python_version = ".".join(map(str, sys.version_info[:2])) + if system.lower() == "windows": + os_version = "windows" + elif system.lower() == "darwin": + os_version = "macos" + elif system.lower() == "linux": + os_version = distro.name().lower() + else: + raise ValueError("Unknown System") + if (python_version == '3.11'): + python_version = '3.10' + return os_version, python_version + + +def untar_file(tar_file_path, extract_folder_path): + with tarfile.open(tar_file_path, 'r:gz') as tarObj: 
+ tarObj.extractall(extract_folder_path) + + # Package meta-data. NAME = 'pygmtools' DESCRIPTION = 'pygmtools provides graph matching solvers in Python API and supports numpy and pytorch backends. ' \ @@ -24,11 +50,21 @@ def get_property(prop, project): URL = 'https://pygmtools.readthedocs.io/' AUTHOR = get_property('__author__', NAME) VERSION = get_property('__version__', NAME) - REQUIRED = [ 'requests>=2.25.1', 'scipy>=1.4.1', 'Pillow>=7.2.0', 'numpy>=1.18.5', 'easydict>=1.7', 'appdirs>=1.4.4', 'tqdm>=4.64.1','wget>=3.2' ] - +FILE={'windows':{ '3.7':'a_star.cp37-win_amd64.pyd', + '3.8':'a_star.cp38-win_amd64.pyd', + '3.9':'a_star.cp39-win_amd64.pyd', + '3.10':'a_star.cp310-win_amd64.pyd'}, + 'macos' :{ '3.7':'a_star.cpython-37m-darwin.so', + '3.8':'a_star.cpython-38-darwin.so', + '3.9':'a_star.cpython-39-darwin.so', + '3.10':'a_star.cpython-310-darwin.so'}, + 'ubuntu' :{ '3.7':'a_star.cpython-37m-x86_64-linux-gnu.so', + '3.8':'a_star.cpython-38-x86_64-linux-gnu.so', + '3.9':'a_star.cpython-39-x86_64-linux-gnu.so', + '3.10':'a_star.cpython-310-x86_64-linux-gnu.so'}} EXTRAS = {} here = os.path.abspath(os.path.dirname(__file__)) @@ -50,6 +86,18 @@ def get_property(prop, project): else: about['__version__'] = VERSION +if os.path.exists(os.path.join(NAME,'astar/a_star.tar.gz')): + untar_file(os.path.join(NAME,'astar/a_star.tar.gz'),os.path.join(NAME,'astar')) + + +class CustomBdistWheelCommand(_bdist_wheel): + def run(self): + os_version, python_version = get_os_and_python_version() + dynamic_link = FILE[os_version][python_version] + shutil.copy2(os.path.join(NAME, 'astar', dynamic_link), os.path.join(NAME, dynamic_link)) + shutil.rmtree(os.path.join(NAME, 'astar')) + _bdist_wheel.run(self) + class UploadCommand(Command): """Support setup.py upload.""" @@ -97,12 +145,13 @@ def run(self): author=AUTHOR, url=URL, packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), + package_data={NAME: ['astar/*','*.pyd','*.so']}, install_requires=REQUIRED, 
extras_require=EXTRAS, include_package_data=True, license='Mulan PSL v2', python_requires='>=3.7', - classifiers=( + classifiers=[ 'License :: OSI Approved :: Mulan Permissive Software License v2 (MulanPSL-2.0)', 'Programming Language :: Python :: 3 :: Only', 'Operating System :: OS Independent', @@ -111,9 +160,10 @@ def run(self): 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Image Recognition', 'Topic :: Scientific/Engineering :: Mathematics', - ), + ], # $ setup.py publish support. cmdclass={ 'upload': UploadCommand, + 'bdist_wheel': CustomBdistWheelCommand, }, ) diff --git a/tests/requirements.txt b/tests/requirements.txt index 258519a7..a1b9cfe1 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -10,4 +10,8 @@ tqdm jittor==1.3.5.37 appdirs>=1.4.4 tensorflow==2.9.3 +distro +cython wget +networkx==2.8.8 + diff --git a/tests/requirements_win_mac.txt b/tests/requirements_win_mac.txt index 948db8d6..5b0fcd8d 100644 --- a/tests/requirements_win_mac.txt +++ b/tests/requirements_win_mac.txt @@ -10,4 +10,7 @@ tqdm appdirs>=1.4.4 tensorflow==2.9.3 mindspore==1.10.0 +distro +cython wget +networkx==2.8.8 \ No newline at end of file diff --git a/tests/test_a_star/a_star.pyx b/tests/test_a_star/a_star.pyx new file mode 100644 index 00000000..619ead7f --- /dev/null +++ b/tests/test_a_star/a_star.pyx @@ -0,0 +1,166 @@ +# distutils: language = c++ +import torch +import numpy as np +cimport cython +cimport numpy as np +#from libcpp.queue cimport priority_queue +from libcpp.vector cimport vector +from libcpp.pair cimport pair +from libcpp cimport bool + +cdef extern from "priority_queue.hpp": + cdef cppclass TreeNode: + TreeNode() + TreeNode(long) + #TreeNode(vector[pair[long, long]], double, long) + pair[vector[long], vector[long]] x_indices + double gplsh + long idx + + cdef cppclass tree_node_priority_queue: + tree_node_priority_queue(...) 
# get Cython to accept any arguments and let C++ deal with getting them right + void push(TreeNode) + TreeNode top() + void pop() + bool empty() + long size() + +@cython.boundscheck(False) +@cython.wraparound(False) +def a_star( + data, + k, + vector[long] ns_1, + vector[long] ns_2, + net_pred_func, + heuristic_func, + bool net_pred=True, + long beam_width=0, + double trust_fact=1., + long no_pred_size=0, +): + # declare static dtypes + cdef long batch_num, b, n1, n2, _n2, ns_1b, ns_2b, max_ns_1, max_ns_2, extra_n2_cnt + cdef vector[long] tree_size + cdef double h_p, g_p + cdef vector[tree_node_priority_queue] open_set + cdef tree_node_priority_queue cur_set + cdef TreeNode selected, new_node + cdef vector[bool] stop_flags + cdef bool flag + + batch_num = k.shape[0] + + max_ns_1 = max(ns_1) + max_ns_2 = max(ns_2) + + open_set = vector[tree_node_priority_queue](batch_num) + tree_size = vector[long](batch_num) + for b in range(batch_num): + open_set[b].push(TreeNode()) + ret_x = torch.zeros(batch_num, max_ns_1+1, max_ns_2+1, device=k.device) + x_dense = torch.zeros(max_ns_1+1, max_ns_2+1, device=k.device) + stop_flags = vector[bool](batch_num, 0) + while not all(stop_flags): + for b in range(batch_num): + ns_1b = ns_1[b] + ns_2b = ns_2[b] + + if stop_flags[b] == 1: + continue + + selected = open_set[b].top() + open_set[b].pop() + #selected_x_indices = torch.tensor(selected.x_indices, dtype=torch.long).reshape(-1, 2) + if selected.idx == ns_1b: + stop_flags[b] = 1 + #indices = selected_x_indices + #v = torch.ones(indices.shape[0], device=k.device) + #x = torch.sparse.FloatTensor(indices.t(), v, x_size).to_dense() + ret_x[b][selected.x_indices] = 1 + continue + + if beam_width > 0: + cur_set = tree_node_priority_queue() + flag = False + for n2 in range(ns_2b + 1): + if n2 != ns_2b and is_in(n2, selected.x_indices.second): + continue + if selected.idx + 1 == ns_1b: + flag = True + extra_n2_cnt = 0 + for _n2 in range(ns_2b): + if _n2 != n2 and not is_in(_n2, 
selected.x_indices.second): + extra_n2_cnt += 1 + new_node = TreeNode(ns_1b + extra_n2_cnt) + n1 = 0 + for _ in range(selected.idx): + new_node.x_indices.first[n1] = selected.x_indices.first[n1] + new_node.x_indices.second[n1] = selected.x_indices.second[n1] + n1 += 1 + new_node.x_indices.first[n1] = selected.idx + new_node.x_indices.second[n1] = n2 + n1 += 1 + for _n2 in range(ns_2b): + if _n2 != n2 and not is_in(_n2, selected.x_indices.second): + new_node.x_indices.first[n1] = ns_1b + new_node.x_indices.second[n1] = _n2 + n1 += 1 + else: + new_node = TreeNode(selected.idx + 1) + n1 = 0 + for _ in range(selected.idx): + new_node.x_indices.first[n1] = selected.x_indices.first[n1] + new_node.x_indices.second[n1] = selected.x_indices.second[n1] + n1 += 1 + new_node.x_indices.first[n1] = selected.idx + new_node.x_indices.second[n1] = n2 + n1 += 1 + + x_dense[:] = 0 + x_dense[new_node.x_indices] = 1 + + g_p = comp_ged(x_dense, k[b]) + + if net_pred: + if selected.idx + 1 == ns_1b or trust_fact <= 0. or ns_1b - (selected.idx + 1) < no_pred_size: + h_p = 0 + else: + h_p = net_pred_func(data, x_dense) + else: + if selected.idx + 1 == ns_1b or trust_fact <= 0. 
or ns_1b - (selected.idx + 1) < no_pred_size: + h_p = 0 + else: + h_p = heuristic_func(k[b], ns_1b, ns_2b, x_dense) + + new_node.gplsh = g_p + h_p * trust_fact + new_node.idx = selected.idx + 1 + + if beam_width > 0: + cur_set.push(new_node) + else: + open_set[b].push(new_node) + tree_size[b] += 1 + if(flag): + break + + if beam_width > 0: + for i in range(min(beam_width, cur_set.size())): + open_set[b].push(cur_set.top()) + cur_set.pop() + tree_size[b] += 1 + + return ret_x, tree_size + + +cdef double comp_ged(_x, _k): + return torch.mm(torch.mm(_x.reshape( 1, -1), _k), _x.reshape( -1, 1)) + +cdef bool is_in (long inp, vector[long] vec): + cdef unsigned long i + cdef bool ret = False + for i in range(vec.size()): + if inp == vec[i]: + ret = True + break + return ret diff --git a/tests/test_a_star/a_star_setup.py b/tests/test_a_star/a_star_setup.py new file mode 100644 index 00000000..2e76312e --- /dev/null +++ b/tests/test_a_star/a_star_setup.py @@ -0,0 +1,18 @@ +from setuptools import setup, Extension +from Cython.Build import cythonize +import numpy as np +from glob import glob +setup( + name='a-star function', + ext_modules=cythonize( + Extension( + 'a_star', + glob('*.pyx'), + include_dirs=[np.get_include(),"."], + extra_compile_args=["-std=c++11"], + extra_link_args=["-std=c++11"], + ), + language_level = "3", + ), + zip_safe=False, +) diff --git a/tests/test_a_star/prepare_for_test.py b/tests/test_a_star/prepare_for_test.py new file mode 100644 index 00000000..16096693 --- /dev/null +++ b/tests/test_a_star/prepare_for_test.py @@ -0,0 +1,26 @@ +import os +import glob +import shutil + +ori_dir = os.getcwd() +os.chdir('tests/test_a_star') + +try: + os.system("python a_star_setup.py build_ext --inplace") +except: + os.system("python3 a_star_setup.py build_ext --inplace") + +current_dir = os.getcwd() + +ext_files = glob.glob(os.path.join(current_dir, '*.pyd')) + \ + glob.glob(os.path.join(current_dir, '*.so')) + +if len(ext_files) == 0: + raise ValueError("there 
is no .pyd or .so") +elif len(ext_files) > 1: + raise ValueError("too many files end with .pyd or .so") +else: + target_dir = os.path.abspath(os.path.join(current_dir,'..','..','pygmtools')) + shutil.copy(ext_files[0], target_dir) + +os.chdir(ori_dir) diff --git a/tests/test_a_star/priority_queue.hpp b/tests/test_a_star/priority_queue.hpp new file mode 100644 index 00000000..7c2905d6 --- /dev/null +++ b/tests/test_a_star/priority_queue.hpp @@ -0,0 +1,41 @@ +#include +#include + +struct TreeNode +{ + std::pair, std::vector > x_indices; + double gplsh; + long idx; + TreeNode(); + TreeNode(const int &); + TreeNode(const std::pair, std::vector > &, const double &, const long &); + bool operator>(const TreeNode &) const; +}; + +TreeNode::TreeNode() +{ + this->x_indices = std::pair, std::vector >(); + this->gplsh = 0; + this->idx = 0; +} + +TreeNode::TreeNode(const int & len) +{ + this->x_indices = std::pair, std::vector >(std::vector(len), std::vector(len)); + this->gplsh = 0; + this->idx = 0; +} + +TreeNode::TreeNode(const std::pair, std::vector > &x_indices, const double &gplsh, const long &idx) +{ + this->x_indices = x_indices; + this->gplsh = gplsh; + this->idx = idx; +} + +bool TreeNode::operator>(const TreeNode &c) const +{ + return this->gplsh > c.gplsh; +} + +using tree_node_priority_queue = std::priority_queue, std::greater >; diff --git a/tests/test_classic_solvers.py b/tests/test_classic_solvers.py index a5dbf9a6..fbd48ee2 100644 --- a/tests/test_classic_solvers.py +++ b/tests/test_classic_solvers.py @@ -27,6 +27,8 @@ def get_backends(backend): if backend == "all": backends = ['pytorch', 'numpy', 'paddle', 'jittor', 'tensorflow'] if os_name == 'Linux' else ['pytorch', 'numpy', 'paddle', 'tensorflow'] + elif backend == '': + backends = ['pytorch'] else: backends = ["pytorch", backend] return backends @@ -208,6 +210,110 @@ def _test_classic_solver_on_linear_assignment(num_nodes1, num_nodes2, node_feat_ last_X = pygm.utils.to_numpy(_X) +# The testing function 
for a_star +def _test_astar(graph_num_nodes, node_feat_dim, solver_func, matrix_params, backends): + if backends[0] != 'pytorch': + backends.insert(0, 'pytorch') # force pytorch as the reference backend + backends = ['pytorch'] # Due to currently only supporting pytorch, testing is only conducted under pytorch + batch_size = len(graph_num_nodes) + + # Generate isomorphic graphs + pygm.BACKEND = 'pytorch' + torch.manual_seed(0) + X_gt, A1, A2, F1, F2, = [], [], [], [], [], + for b, num_node in enumerate(graph_num_nodes): + As_b, X_gt_b, Fs_b = pygm.utils.generate_isomorphic_graphs(num_node, node_feat_dim=node_feat_dim) + Fs_b = Fs_b - 0.5 + X_gt.append(X_gt_b) + A1.append(As_b[0]) + A2.append(As_b[1]) + F1.append(Fs_b[0]) + F2.append(Fs_b[1]) + n1 = torch.tensor(graph_num_nodes, dtype=torch.int) + n2 = torch.tensor(graph_num_nodes, dtype=torch.int) + A1, A2, F1, F2, X_gt = (pygm.utils.build_batch(_) for _ in (A1, A2, F1, F2, X_gt)) + if batch_size > 1: + A1, A2, F1, F2, n1, n2, X_gt = data_to_numpy(A1, A2, F1, F2, n1, n2, X_gt) + else: + A1, A2, F1, F2, n1, n2, X_gt = data_to_numpy( + A1.squeeze(0), A2.squeeze(0), F1.squeeze(0), F2.squeeze(0), n1, n2, X_gt.squeeze(0) + ) + + # call the solver + total = 1 + for val in matrix_params.values(): + total *= len(val) + for values in tqdm(itertools.product(*matrix_params.values()), total=total): + solver_param_dict = {} + for k, v in zip(matrix_params.keys(), values): + solver_param_dict[k] = v + + last_X = None + for working_backend in backends: + pygm.BACKEND = working_backend + _A1, _A2, _F1, _F2, _n1, _n2 = data_from_numpy(A1, A2, F1, F2, n1, n2) + _X1 = solver_func(_F1, _F2, _A1, _A2, _n1, _n2, **solver_param_dict) + + if last_X is not None: + assert np.abs(pygm.utils.to_numpy(_X1) - last_X).sum() < 5e-3, \ + f"Incorrect GM solution for {working_backend}; " \ + f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}" + + last_X = pygm.utils.to_numpy(_X1) + accuracy = 
(pygm.utils.to_numpy(pygm.hungarian(_X1, _n1, _n2)) * X_gt).sum() / X_gt.sum() + assert accuracy == 1, f"GM is inaccurate for {working_backend}, accuracy={accuracy:.4f}; " \ + f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}" + + +# The testing function for networkx +def _test_networkx(graph_num_nodes, backends): + """ + Test the RRWM algorithm on pairs of isomorphic graphs using NetworkX + + :param graph_num_nodes: list, the numbers of nodes in the graphs to test + """ + for working_backend in backends: + pygm.BACKEND = working_backend + for num_node in tqdm(graph_num_nodes): + As_b, X_gt = pygm.utils.generate_isomorphic_graphs(num_node) + X_gt = pygm.utils.to_numpy(X_gt, backend=working_backend) + A1 = As_b[0] + A2 = As_b[1] + G1 = pygm.utils.to_networkx(A1) + G2 = pygm.utils.to_networkx(A2) + K = pygm.utils.build_aff_mat_from_networkx(G1, G2) + X = pygm.rrwm(K, n1=num_node, n2=num_node) + accuracy = (pygm.utils.to_numpy(pygm.hungarian(X, num_node, num_node)) * X_gt).sum() / X_gt.sum() + assert accuracy == 1, f'When testing the networkx function with rrwm algorithm, there is an error in accuracy, \ + and the accuracy is {accuracy}, the num_node is {num_node},.' 
+ + +# The testing fuction for graphml +def _test_graphml(graph_num_nodes, backends): + """ + Test the RRWM algorithm on pairs of isomorphic graphs using graphml + + :param graph_num_nodes: list, the numbers of nodes in the graphs to test + """ + filename = 'examples/data/test_graphml_{}.graphml' + filename_1 = filename.format(1) + filename_2 = filename.format(2) + for working_backend in backends: + pygm.BACKEND = working_backend + for num_node in tqdm(graph_num_nodes): + As_b, X_gt = pygm.utils.generate_isomorphic_graphs(num_node) + X_gt = pygm.utils.to_numpy(X_gt, backend=working_backend) + A1 = As_b[0] + A2 = As_b[1] + pygm.utils.to_graphml(A1, filename_1, backend=working_backend) + pygm.utils.to_graphml(A2, filename_2, backend=working_backend) + K = pygm.utils.build_aff_mat_from_graphml(filename_1, filename_2) + X = pygm.rrwm(K, n1=num_node, n2=num_node) + accuracy = (pygm.utils.to_numpy(pygm.hungarian(X, num_node, num_node)) * X_gt).sum() / X_gt.sum() + assert accuracy == 1, f'When testing the graphml function with rrwm algorithm, there is an error in accuracy, \ + and the accuracy is {accuracy}, the num_node is {num_node},.' 
+ + def test_hungarian(get_backend): backends = get_backends(get_backend) _test_classic_solver_on_linear_assignment(list(range(10, 30, 2)), list(range(30, 10, -2)), 10, pygm.hungarian, { @@ -343,6 +449,36 @@ def test_ipfp(get_backend): 'edge_aff_fn': [functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.)], 'node_aff_fn': [functools.partial(pygm.utils.gaussian_aff_fn, sigma=.1)] }, backends) + + +def test_astar(get_backend): + backends = get_backends(get_backend) + # heuristic_prediction + args1 = (list(range(10, 16, 2)), 10, pygm.astar,{ + "beam_width": [0, 1, 2], + "trust_fact": [0.9, 0.95, 1.0], + "no_pred_size": [0, 1], + }, backends) + + # non-batched input + args2 = ([10], 10, pygm.astar,{ + "beam_width": [0, 1, 2], + "trust_fact": [0.9, 0.95, 1.0], + "no_pred_size": [0, 1], + }, backends) + + _test_astar(*args1) + _test_astar(*args2) + + +def test_networkx(): + backends = ['pytorch', 'numpy'] + _test_networkx(list(range(10, 30, 2)), backends=backends) + + +def test_graphml(): + backends = ['pytorch', 'numpy'] + _test_graphml(list(range(10, 30, 2)), backends=backends) if __name__ == '__main__': @@ -351,3 +487,6 @@ def test_ipfp(get_backend): test_rrwm('all') test_sm('all') test_ipfp('all') + test_astar('') + test_networkx() + test_graphml() diff --git a/tests/test_neural_solvers.py b/tests/test_neural_solvers.py index d6fe7a42..64b1e28b 100644 --- a/tests/test_neural_solvers.py +++ b/tests/test_neural_solvers.py @@ -136,6 +136,71 @@ def _test_neural_solver_on_isomorphic_graphs(graph_num_nodes, node_feat_dim, sol f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}" +# The testing function for genn_astar +def _test_genn_astar(graph_num_nodes, node_feat_dim, solver_func, matrix_params, backends): + if backends[0] != 'pytorch': + backends.insert(0, 'pytorch') # force pytorch as the reference backend + backends = ['pytorch'] # Due to currently only supporting pytorch, testing is only conducted under pytorch + batch_size = len(graph_num_nodes) + 
+ # Generate isomorphic graphs + pygm.BACKEND = 'pytorch' + torch.manual_seed(0) + X_gt, A1, A2, F1, F2, = [], [], [], [], [], + for b, num_node in enumerate(graph_num_nodes): + As_b, X_gt_b, Fs_b = pygm.utils.generate_isomorphic_graphs(num_node, node_feat_dim=node_feat_dim) + Fs_b = Fs_b - 0.5 + X_gt.append(X_gt_b) + A1.append(As_b[0]) + A2.append(As_b[1]) + F1.append(Fs_b[0]) + F2.append(Fs_b[1]) + n1 = torch.tensor(graph_num_nodes, dtype=torch.int) + n2 = torch.tensor(graph_num_nodes, dtype=torch.int) + A1, A2, F1, F2, X_gt = (pygm.utils.build_batch(_) for _ in (A1, A2, F1, F2, X_gt)) + if batch_size > 1: + A1, A2, F1, F2, n1, n2, X_gt = data_to_numpy(A1, A2, F1, F2, n1, n2, X_gt) + else: + A1, A2, F1, F2, n1, n2, X_gt = data_to_numpy( + A1.squeeze(0), A2.squeeze(0), F1.squeeze(0), F2.squeeze(0), n1, n2, X_gt.squeeze(0) + ) + + # call the solver + total = 1 + for val in matrix_params.values(): + total *= len(val) + for values in tqdm(itertools.product(*matrix_params.values()), total=total): + solver_param_dict = {} + for k, v in zip(matrix_params.keys(), values): + solver_param_dict[k] = v + + last_X = None + for working_backend in backends: + pygm.BACKEND = working_backend + _A1, _A2, _F1, _F2, _n1, _n2 = data_from_numpy(A1, A2, F1, F2, n1, n2) + _X1, net = solver_func(_F1, _F2, _A1, _A2, _n1, _n2, return_network=True, **solver_param_dict) + _X2 = solver_func(_F1, _F2, _A1, _A2, _n1, _n2, network=net, **solver_param_dict) + net2 = pygm.utils.get_network(solver_func, **solver_param_dict) + assert type(net) == type(net2) + + assert np.abs(pygm.utils.to_numpy(_X1) - pygm.utils.to_numpy(_X2)).sum() < 1e-4, \ + f"GM result inconsistent for predefined network object. 
backend={working_backend}; " \ + f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}" + + if 'pretrain' in solver_param_dict and solver_param_dict['pretrain'] is None: + _X1 = pygm.hungarian(_X1, _n1, _n2) + + if last_X is not None: + assert np.abs(pygm.utils.to_numpy(_X1) - last_X).sum() < 5e-3, \ + f"Incorrect GM solution for {working_backend}; " \ + f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}" + + last_X = pygm.utils.to_numpy(_X1) + accuracy = (pygm.utils.to_numpy(pygm.hungarian(_X1, _n1, _n2)) * X_gt).sum() / X_gt.sum() + assert accuracy == 1, f"GM is inaccurate for {working_backend}, accuracy={accuracy:.4f}; " \ + f"{';'.join([k + '=' + str(v) for k, v in solver_param_dict.items()])}" + + def test_pca_gm(): _test_neural_solver_on_isomorphic_graphs(list(range(10, 30, 2)), 1024, pygm.pca_gm, 'individual-graphs', { 'pretrain': ['voc', 'willow', 'voc-all'], @@ -186,6 +251,7 @@ def test_cie(): 'pretrain': [None], }, backends) + def test_ngm(): _test_neural_solver_on_isomorphic_graphs(list(range(10, 30, 2)), 1024, pygm.ngm, 'lawler-qap', { 'edge_aff_fn': [functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.), pygm.utils.inner_prod_aff_fn], @@ -201,8 +267,49 @@ def test_ngm(): }, backends) +def test_genn_astar(): + # test pretrained by AIDS700nef + args1 = (list(range(10, 30, 2)), 36, pygm.genn_astar,{ + "pretrain": ["AIDS700nef"], + "beam_width": [0, 1, 2], + "trust_fact": [0.9, 0.95, 1.0], + "no_pred_size": [0, 1], + + }, backends) + + # non-batched input + args2 = ([10], 36, pygm.genn_astar,{ + 'pretrain': ["AIDS700nef"], + "beam_width": [0, 1, 2], + "trust_fact": [0.9, 0.95, 1.0], + "no_pred_size": [0, 1], + }, backends) + + # test pretrained by LINUX + args3 = (list(range(10, 30, 2)), 8, pygm.genn_astar,{ + 'pretrain': ['LINUX'], + "beam_width": [0, 1, 2], + "trust_fact": [0.9, 0.95, 1.0], + "no_pred_size": [0, 1], + }, backends) + + # non-batched input + args4 = ([10], 8, pygm.genn_astar,{ + 'pretrain': ['LINUX'], 
+ "beam_width": [0, 1, 2], + "trust_fact": [0.9, 0.95, 1.0], + "no_pred_size": [0, 1], + }, backends) + + _test_genn_astar(*args1) + _test_genn_astar(*args2) + _test_genn_astar(*args3) + _test_genn_astar(*args4) + + if __name__ == '__main__': test_pca_gm() test_ipca_gm() test_cie() test_ngm() + test_genn_astar() diff --git a/tests/test_utils.py b/tests/test_utils.py index c0c123e7..25cc4434 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -29,3 +29,4 @@ def data_to_numpy(*data): return return_list else: return return_list[0] + \ No newline at end of file