-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) 2023, Ivan Moskalenko, Anastasiia Kornilova | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from uav_loc.vpr_systems.netvlad.netvlad import NetVLAD |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright (c) 2023, Stephen Hausler, Sourav Garg, Ming Xu, Michael Milford, Tobias Fischer, | ||
# Ivan Moskalenko, Anastasiia Kornilova | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# Significant part of our code is based on Patch-NetVLAD repository | ||
# (https://github.com/QVPR/Patch-NetVLAD) | ||
import faiss | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from sklearn.neighbors import NearestNeighbors | ||
|
||
|
||
class NetVLADModule(nn.Module): | ||
"""NetVLAD layer implementation""" | ||
|
||
def __init__( | ||
self, | ||
num_clusters=64, | ||
dim=128, | ||
normalize_input=True, | ||
vladv2=False, | ||
use_faiss=True, | ||
): | ||
""" | ||
Args: | ||
num_clusters : int | ||
The number of clusters | ||
dim : int | ||
Dimension of descriptors | ||
normalize_input : bool | ||
If true, descriptor-wise L2 normalization is applied to input. | ||
vladv2 : bool | ||
If true, use vladv2 otherwise use vladv1 | ||
""" | ||
super().__init__() | ||
self.num_clusters = num_clusters | ||
self.dim = dim | ||
self.alpha = 0 | ||
self.vladv2 = vladv2 | ||
self.normalize_input = normalize_input | ||
self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=vladv2) | ||
self.centroids = nn.Parameter(torch.rand(num_clusters, dim)) | ||
self.use_faiss = use_faiss | ||
|
||
def init_params(self, clsts, traindescs): | ||
if not self.vladv2: | ||
clstsAssign = clsts / np.linalg.norm(clsts, axis=1, keepdims=True) | ||
dots = np.dot(clstsAssign, traindescs.T) | ||
dots.sort(0) | ||
dots = dots[::-1, :] # sort, descending | ||
|
||
self.alpha = (-np.log(0.01) / np.mean(dots[0, :] - dots[1, :])).item() | ||
self.centroids = nn.Parameter(torch.from_numpy(clsts)) | ||
self.conv.weight = nn.Parameter( | ||
torch.from_numpy(self.alpha * clstsAssign).unsqueeze(2).unsqueeze(3) | ||
) | ||
self.conv.bias = None | ||
else: | ||
if not self.use_faiss: | ||
knn = NearestNeighbors(n_jobs=-1) | ||
knn.fit(traindescs) | ||
del traindescs | ||
ds_sq = np.square(knn.kneighbors(clsts, 2)[1]) | ||
del knn | ||
else: | ||
index = faiss.IndexFlatL2(traindescs.shape[1]) | ||
index.add(traindescs) | ||
del traindescs | ||
ds_sq = np.square(index.search(clsts, 2)[1]) | ||
del index | ||
|
||
self.alpha = (-np.log(0.01) / np.mean(ds_sq[:, 1] - ds_sq[:, 0])).item() | ||
self.centroids = nn.Parameter(torch.from_numpy(clsts)) | ||
del clsts, ds_sq | ||
|
||
self.conv.weight = nn.Parameter( | ||
(2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1) | ||
) | ||
self.conv.bias = nn.Parameter(-self.alpha * self.centroids.norm(dim=1)) | ||
|
||
def forward(self, x): | ||
N, C = x.shape[:2] | ||
|
||
if self.normalize_input: | ||
x = F.normalize(x, p=2, dim=1) # across descriptor dim | ||
|
||
# soft-assignment | ||
soft_assign = self.conv(x).view(N, self.num_clusters, -1) | ||
soft_assign = F.softmax(soft_assign, dim=1) | ||
|
||
x_flatten = x.view(N, C, -1) | ||
|
||
# calculate residuals to each clusters | ||
vlad = torch.zeros( | ||
[N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device | ||
) | ||
for C in range( | ||
self.num_clusters | ||
): # slower than non-looped, but lower memory usage | ||
residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - self.centroids[ | ||
C : C + 1, : | ||
].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) | ||
residual *= soft_assign[:, C : C + 1, :].unsqueeze(2) | ||
vlad[:, C : C + 1, :] = residual.sum(dim=-1) | ||
|
||
vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization | ||
vlad = vlad.view(x.size(0), -1) # flatten | ||
vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize | ||
|
||
return vlad |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright (c) 2023, Stephen Hausler, Sourav Garg, Ming Xu, Michael Milford, Tobias Fischer, | ||
# Ivan Moskalenko, Anastasiia Kornilova | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# Significant part of our code is based on Patch-NetVLAD repository | ||
# (https://github.com/QVPR/Patch-NetVLAD) | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torchvision.models as models | ||
|
||
from uav_loc.vpr_systems.netvlad.model.layers import NetVLADModule | ||
|
||
class Flatten(nn.Module): | ||
def forward(self, input_data): | ||
return input_data.view(input_data.size(0), -1) | ||
|
||
|
||
class L2Norm(nn.Module): | ||
def __init__(self, dim=1): | ||
super().__init__() | ||
self.dim = dim | ||
|
||
def forward(self, input_data): | ||
return F.normalize(input_data, p=2, dim=self.dim) | ||
|
||
def get_backend(): | ||
enc_dim = 512 | ||
enc = models.vgg16(weights="IMAGENET1K_V1") | ||
layers = list(enc.features.children())[:-2] | ||
# only train conv5_1, conv5_2, and conv5_3 (leave rest same as Imagenet trained weights) | ||
for layer in layers[:-5]: | ||
for p in layer.parameters(): | ||
p.requires_grad = False | ||
enc = nn.Sequential(*layers) | ||
return enc_dim, enc | ||
|
||
def get_pca_encoding(model, vlad_encoding): | ||
pca_encoding = model.WPCA(vlad_encoding.unsqueeze(-1).unsqueeze(-1)) | ||
return pca_encoding | ||
|
||
def get_model( | ||
encoder, | ||
encoder_dim, | ||
num_clusters, | ||
use_vladv2=False, | ||
append_pca_layer=False, | ||
num_pcs=8192, | ||
): | ||
nn_model = nn.Module() | ||
nn_model.add_module("encoder", encoder) | ||
|
||
net_vlad = NetVLADModule( | ||
num_clusters=num_clusters, dim=encoder_dim, vladv2=use_vladv2 | ||
) | ||
nn_model.add_module("pool", net_vlad) | ||
if append_pca_layer: | ||
netvlad_output_dim = encoder_dim | ||
netvlad_output_dim *= num_clusters | ||
pca_conv = nn.Conv2d( | ||
netvlad_output_dim, num_pcs, kernel_size=(1, 1), stride=1, padding=0 | ||
) | ||
nn_model.add_module( | ||
"WPCA", nn.Sequential(*[pca_conv, Flatten(), L2Norm(dim=-1)]) | ||
) | ||
return nn_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) 2023, Ivan Moskalenko, Anastasiia Kornilova | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import torch | ||
|
||
from pathlib import Path | ||
from PIL import Image | ||
|
||
from uav_loc.utils import transform_image | ||
from uav_loc.vpr_systems.netvlad.model.models_generic import ( | ||
get_backend, | ||
get_model, | ||
get_pca_encoding, | ||
) | ||
from uav_loc.vpr_systems.vpr_system import VPRSystem | ||
|
||
|
||
class NetVLAD(VPRSystem): | ||
""" | ||
Implementation of [NetVLAD](https://github.com/QVPR/Patch-NetVLAD) global localization method. | ||
""" | ||
|
||
def __init__(self, path_to_weights: str, resize: int = 800): | ||
super().__init__() | ||
self.resize = resize | ||
encoder_dim, encoder = get_backend() | ||
|
||
checkpoint = torch.load( | ||
path_to_weights, map_location=lambda storage, loc: storage | ||
) | ||
num_clusters = checkpoint["state_dict"]["pool.centroids"].shape[0] | ||
num_pcs = checkpoint["state_dict"]["WPCA.0.bias"].shape[0] | ||
self.model = get_model( | ||
encoder, | ||
encoder_dim, | ||
num_clusters, | ||
append_pca_layer=True, | ||
num_pcs=num_pcs, | ||
) | ||
self.model.load_state_dict(checkpoint["state_dict"]) | ||
self.model = self.model.to(self.device) | ||
self.model.eval() | ||
|
||
def get_image_descriptor( | ||
self, | ||
image_path: Path, | ||
): | ||
image = Image.open(image_path).convert("RGB") | ||
image = transform_image(image, self.resize)[None, :].to(self.device) | ||
|
||
with torch.no_grad(): | ||
image_encoding = self.model.encoder(image) | ||
vlad_global = self.model.pool(image_encoding) | ||
vlad_global_pca = get_pca_encoding(self.model, vlad_global) | ||
desc = vlad_global_pca.detach().cpu().numpy()[0] | ||
return desc | ||
|
||
get_image_descriptor.__doc__ = VPRSystem.get_image_descriptor.__doc__ |