diff --git a/fastreid/evaluation/GPU-Re-Ranking/README.md b/fastreid/evaluation/GPU-Re-Ranking/README.md new file mode 100644 index 000000000..349a9ef63 --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/README.md @@ -0,0 +1,37 @@ +# Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective + +[[Paper]](https://arxiv.org/abs/2012.07620v2) + +On the Market-1501 dataset, we accelerate the re-ranking processing from **89.2s** to **9.4ms** with one K40m GPU, facilitating the real-time post-processing. +Similarly, we observe that our method achieves comparable or even better retrieval results on the other four image retrieval benchmarks, +i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652, with limited time cost. + +## Prerequisites + +The code was mainly developed and tested with python 3.7, PyTorch 1.4.1, CUDA 10.2, and CentOS release 6.10. + +The code has been included in `/extension`. To compile it: + +```shell +cd extension +sh make.sh +``` + +## Demo + +The demo script `main.py` provides the gnn re-ranking method using the prepared feature. + +```shell +python main.py --data_path PATH_TO_DATA --k1 26 --k2 7 +``` + +## Citation +```bibtex +@article{zhang2020understanding, + title={Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective}, + author={Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang}, + journal={arXiv preprint arXiv:2012.07620}, + year={2020} +} +``` + diff --git a/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/build_adjacency_matrix.cpp b/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/build_adjacency_matrix.cpp new file mode 100644 index 000000000..4c496041e --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/build_adjacency_matrix.cpp @@ -0,0 +1,19 @@ +#include +#include +#include + +at::Tensor build_adjacency_matrix_forward(torch::Tensor initial_rank); + + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +at::Tensor build_adjacency_matrix(at::Tensor initial_rank) { + CHECK_INPUT(initial_rank); + return build_adjacency_matrix_forward(initial_rank); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &build_adjacency_matrix, "build_adjacency_matrix (CUDA)"); +} diff --git a/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/build_adjacency_matrix_kernel.cu b/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/build_adjacency_matrix_kernel.cu new file mode 100644 index 000000000..4973ddefe --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/build_adjacency_matrix_kernel.cu @@ -0,0 +1,31 @@ +#include + +#include +#include +#include + +#define CUDA_1D_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) + + +__global__ void build_adjacency_matrix_kernel(float* initial_rank, float* A, const int total_num, const int topk, const int nthreads, const int all_num) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < all_num; i += stride) { + int ii = i / topk; + A[ii * total_num + int(initial_rank[i])] = float(1.0); + } +} + +at::Tensor build_adjacency_matrix_forward(at::Tensor initial_rank) { + const auto total_num = initial_rank.size(0); + const auto topk = initial_rank.size(1); + const auto all_num = total_num * topk; + auto A = torch::zeros({total_num, total_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float)); + + const int threads = 1024; + const int blocks = (all_num + threads - 1) / threads; + + build_adjacency_matrix_kernel<<>>(initial_rank.data_ptr(), A.data_ptr(), total_num, topk, threads, all_num); + return A; + +} diff --git a/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/setup.py b/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/setup.py new file mode 100644 index 000000000..7b4e1e5e8 --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/setup.py @@ -0,0 +1,37 @@ +""" + Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective + + Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang + + Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking + + Paper: https://arxiv.org/abs/2012.07620v2 + + ====================================================================== + + On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms + with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe + that our method achieves comparable or even better retrieval results on the other four + image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652, + with limited time cost. +""" + +from setuptools import setup, Extension + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name='build_adjacency_matrix', + ext_modules=[ + CUDAExtension('build_adjacency_matrix', [ + 'build_adjacency_matrix.cpp', + 'build_adjacency_matrix_kernel.cu', + ]), + ], + cmdclass={ + 'build_ext':BuildExtension + }) diff --git a/fastreid/evaluation/GPU-Re-Ranking/extension/make.sh b/fastreid/evaluation/GPU-Re-Ranking/extension/make.sh new file mode 100644 index 000000000..f0197ff9c --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/extension/make.sh @@ -0,0 +1,4 @@ +cd adjacency_matrix +python setup.py install +cd ../propagation +python setup.py install \ No newline at end of file diff --git a/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/gnn_propagate.cpp b/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/gnn_propagate.cpp new file mode 100644 index 000000000..10a939ffe --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/gnn_propagate.cpp @@ -0,0 +1,21 @@ +#include +#include +#include + +at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S); + + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +at::Tensor gnn_propagate(at::Tensor A ,at::Tensor initial_rank, at::Tensor S) { + CHECK_INPUT(A); + CHECK_INPUT(initial_rank); + CHECK_INPUT(S); + return gnn_propagate_forward(A, initial_rank, S); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &gnn_propagate, "gnn propagate (CUDA)"); +} \ No newline at end of file diff --git a/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/gnn_propagate_kernel.cu b/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/gnn_propagate_kernel.cu new file mode 100644 index 000000000..8bdebf166 --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/gnn_propagate_kernel.cu @@ -0,0 +1,36 @@ +#include + +#include +#include +#include +#include + +__global__ void gnn_propagate_forward_kernel(float* initial_rank, float* A, float* A_qe, float* S, const int sample_num, const int topk, const int total_num) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < total_num; i += stride) { + int fea = i % sample_num; + int sample_index = i / sample_num; + float sum = 0.0; + for (int j = 0; j < topk ; j++) { + int topk_fea_index = int(initial_rank[sample_index*topk+j]) * sample_num + fea; + sum += A[ topk_fea_index] * S[sample_index*topk+j]; + } + A_qe[i] = sum; + } +} + +at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S) { + const auto sample_num = A.size(0); + const auto topk = initial_rank.size(1); + + const auto total_num = sample_num * sample_num ; + auto A_qe = torch::zeros({sample_num, sample_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float)); + + const int threads = 1024; + const int blocks = (total_num + threads - 1) / threads; + + gnn_propagate_forward_kernel<<>>(initial_rank.data_ptr(), A.data_ptr(), A_qe.data_ptr(), S.data_ptr(), sample_num, topk, total_num); + return A_qe; + +} \ No newline at end of file diff --git a/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/setup.py b/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/setup.py new file mode 100644 index 000000000..a22278651 --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/extension/propagation/setup.py @@ -0,0 +1,37 @@ +""" + Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective + + Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang + + Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking + + Paper: https://arxiv.org/abs/2012.07620v2 + + ====================================================================== + + On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms + with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe + that our method achieves comparable or even better retrieval results on the other four + image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652, + with limited time cost. +""" + +from setuptools import setup, Extension + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name='gnn_propagate', + ext_modules=[ + CUDAExtension('gnn_propagate', [ + 'gnn_propagate.cpp', + 'gnn_propagate_kernel.cu', + ]), + ], + cmdclass={ + 'build_ext':BuildExtension + }) \ No newline at end of file diff --git a/fastreid/evaluation/GPU-Re-Ranking/gnn_reranking.py b/fastreid/evaluation/GPU-Re-Ranking/gnn_reranking.py new file mode 100644 index 000000000..cf0b5277a --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/gnn_reranking.py @@ -0,0 +1,57 @@ +""" + Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective + + Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang + + Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking + + Paper: https://arxiv.org/abs/2012.07620v2 + + ====================================================================== + + On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms + with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe + that our method achieves comparable or even better retrieval results on the other four + image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652, + with limited time cost. +""" + +import torch +import numpy as np + +import build_adjacency_matrix +import gnn_propagate + +from utils import * + + + +def gnn_reranking(X_q, X_g, k1, k2): + query_num, gallery_num = X_q.shape[0], X_g.shape[0] + + X_u = torch.cat((X_q, X_g), axis = 0) + original_score = torch.mm(X_u, X_u.t()) + del X_u, X_q, X_g + + # initial ranking list + S, initial_rank = original_score.topk(k=k1, dim=-1, largest=True, sorted=True) + + # stage 1 + A = build_adjacency_matrix.forward(initial_rank.float()) + S = S * S + + # stage 2 + if k2 != 1: + for i in range(2): + A = A + A.T + A = gnn_propagate.forward(A, initial_rank[:, :k2].contiguous().float(), S[:, :k2].contiguous().float()) + A_norm = torch.norm(A, p=2, dim=1, keepdim=True) + A = A.div(A_norm.expand_as(A)) + + + cosine_similarity = torch.mm(A[:query_num,], A[query_num:, ].t()) + del A, S + + L = torch.sort(-cosine_similarity, dim = 1)[1] + L = L.data.cpu().numpy() + return L diff --git a/fastreid/evaluation/GPU-Re-Ranking/main.py b/fastreid/evaluation/GPU-Re-Ranking/main.py new file mode 100644 index 000000000..d587fd633 --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/main.py @@ -0,0 +1,62 @@ +""" + Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective + + Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang + + Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking + + Paper: https://arxiv.org/abs/2012.07620v2 + + ====================================================================== + + On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms + with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe + that our method achieves comparable or even better retrieval results on the other four + image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652, + with limited time cost. +""" + +import os +import torch +import argparse +import numpy as np + +from utils import * +from gnn_reranking import * + +parser = argparse.ArgumentParser(description='Reranking_is_GNN') +parser.add_argument('--data_path', + type=str, + default='../xm_rerank_gpu_2/features/market_88_test.pkl', + help='path to dataset') +parser.add_argument('--k1', + type=int, + default=26, # Market-1501 + # default=60, # Veri-776 + help='parameter k1') +parser.add_argument('--k2', + type=int, + default=7, # Market-1501 + # default=10, # Veri-776 + help='parameter k2') + +args = parser.parse_args() + +def main(): + data = load_pickle(args.data_path) + + query_cam = data['query_cam'] + query_label = data['query_label'] + gallery_cam = data['gallery_cam'] + gallery_label = data['gallery_label'] + + gallery_feature = torch.FloatTensor(data['gallery_f']) + query_feature = torch.FloatTensor(data['query_f']) + query_feature = query_feature.cuda() + gallery_feature = gallery_feature.cuda() + + indices = gnn_reranking(query_feature, gallery_feature, args.k1, args.k2) + evaluate_ranking_list(indices, query_label, query_cam, gallery_label, gallery_cam) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/fastreid/evaluation/GPU-Re-Ranking/utils.py b/fastreid/evaluation/GPU-Re-Ranking/utils.py new file mode 100644 index 000000000..ba350c55a --- /dev/null +++ b/fastreid/evaluation/GPU-Re-Ranking/utils.py @@ -0,0 +1,119 @@ +""" + Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective + + Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang + + Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking + + Paper: https://arxiv.org/abs/2012.07620v2 + + ====================================================================== + + On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms + with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe + that our method achieves comparable or even better retrieval results on the other four + image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652, + with limited time cost. +""" + +import pickle +import numpy as np +import os +import torch + + +def load_pickle(pickle_path): + with open(pickle_path, 'rb') as f: + data = pickle.load(f) + return data + +def save_pickle(pickle_path, data): + with open(pickle_path, 'wb') as f: + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + +def pairwise_squared_distance(x): + ''' + x : (n_samples, n_points, dims) + return : (n_samples, n_points, n_points) + ''' + x2s = (x * x).sum(-1, keepdim=True) + return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2) + +def pairwise_distance(x, y): + m, n = x.size(0), y.size(0) + + x = x.view(m, -1) + y = y.view(n, -1) + + dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n,m).t() + dist.addmm_(1, -2, x, y.t()) + + return dist + +def cosine_similarity(x, y): + m, n = x.size(0), y.size(0) + + x = x.view(m, -1) + y = y.view(n, -1) + + y = y.t() + score = torch.mm(x, y) + + return score + +def evaluate_ranking_list(indices, query_label, query_cam, gallery_label, gallery_cam): + CMC = np.zeros((len(gallery_label)), dtype=np.int) + ap = 0.0 + + for i in range(len(query_label)): + ap_tmp, CMC_tmp = evaluate(indices[i],query_label[i], query_cam[i], gallery_label, gallery_cam) + if CMC_tmp[0]==-1: + continue + CMC = CMC + CMC_tmp + ap += ap_tmp + + CMC = CMC.astype(np.float32) + CMC = CMC/len(query_label) #average CMC + print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label))) + +def evaluate(index, ql,qc,gl,gc): + query_index = np.argwhere(gl==ql) + camera_index = np.argwhere(gc==qc) + + good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) + junk_index1 = np.argwhere(gl==-1) + junk_index2 = np.intersect1d(query_index, camera_index) + junk_index = np.append(junk_index2, junk_index1) #.flatten()) + + CMC_tmp = compute_mAP(index, good_index, junk_index) + return CMC_tmp + + +def compute_mAP(index, good_index, junk_index): + ap = 0 + cmc = np.zeros((len(index)), dtype=np.int) + if good_index.size==0: # if empty + cmc[0] = -1 + return ap,cmc + + # remove junk_index + mask = np.in1d(index, junk_index, invert=True) + index = index[mask] + + # find good_index index + ngood = len(good_index) + mask = np.in1d(index, good_index) + rows_good = np.argwhere(mask==True) + rows_good = rows_good.flatten() + + cmc[rows_good[0]:] = 1 + for i in range(ngood): + d_recall = 1.0/ngood + precision = (i+1)*1.0/(rows_good[i]+1) + if rows_good[i]!=0: + old_precision = i*1.0/rows_good[i] + else: + old_precision=1.0 + ap = ap + d_recall*(old_precision + precision)/2 + + return ap, cmc \ No newline at end of file