Skip to content

Commit

Permalink
Merge pull request #6 from prime-slam/salad-vpr
Browse files Browse the repository at this point in the history
SALAD
  • Loading branch information
vnmsklnk authored Jan 4, 2024
2 parents a26e22c + bd0b7a4 commit 80f831e
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aero_vloc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
from aero_vloc.primitives import Map, UAVSeq
from aero_vloc.retrieval_system import RetrievalSystem
from aero_vloc.utils import visualize_matches
from aero_vloc.vpr_systems import AnyLoc, CosPlace, EigenPlaces, MixVPR, NetVLAD
from aero_vloc.vpr_systems import AnyLoc, CosPlace, EigenPlaces, MixVPR, NetVLAD, SALAD
1 change: 0 additions & 1 deletion aero_vloc/retrieval_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(
for image in tqdm(
sat_map, desc="Calculating of global descriptors for source DB"
):
image.descriptor = self.vpr_system.get_image_descriptor(image.path)
self.global_descs.append(self.vpr_system.get_image_descriptor(image.path))
self.index.create(np.asarray(self.global_descs))

Expand Down
1 change: 1 addition & 0 deletions aero_vloc/vpr_systems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
from aero_vloc.vpr_systems.eigenplaces import EigenPlaces
from aero_vloc.vpr_systems.mixvpr import MixVPR
from aero_vloc.vpr_systems.netvlad import NetVLAD
from aero_vloc.vpr_systems.salad import SALAD
from aero_vloc.vpr_systems.vpr_system import VPRSystem
14 changes: 14 additions & 0 deletions aero_vloc/vpr_systems/salad/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2023, Ivan Moskalenko, Anastasiia Kornilova
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from aero_vloc.vpr_systems.salad.salad import SALAD
52 changes: 52 additions & 0 deletions aero_vloc/vpr_systems/salad/salad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2023, Ivan Moskalenko, Anastasiia Kornilova
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pathlib import Path
from PIL import Image
from torchvision import transforms as tvf

from aero_vloc.utils import transform_image
from aero_vloc.vpr_systems.vpr_system import VPRSystem


class SALAD(VPRSystem):
"""
Wrapper for [SALAD](https://github.com/serizba/salad) VPR method
"""

def __init__(
self,
resize: int = 800,
gpu_index: int = 0,
):
"""
:param resize: The size to which the larger side of the image will be reduced while maintaining the aspect ratio
:param gpu_index: The index of the GPU to be used
"""
super().__init__(gpu_index)
self.resize = resize
self.model = torch.hub.load("serizba/salad", "dinov2_salad")
self.model.eval().to(self.device)

def get_image_descriptor(self, image_path: Path):
image = Image.open(image_path).convert("RGB")
image = transform_image(image, self.resize).to(self.device)
_, h, w = image.shape
h_new, w_new = (h // 14) * 14, (w // 14) * 14
img_cropped = tvf.CenterCrop((h_new, w_new))(image)[None, ...]
with torch.no_grad():
descriptor = self.model(img_cropped)
descriptor = descriptor.cpu().numpy()[0]
return descriptor
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ geopy==2.4.0
numpy==1.26.1
opencv_python==4.8.1.78
Pillow==10.1.0
prettytable==3.9.0
pytorch_lightning==2.1.0
pytorch-metric-learning==2.4.1
Requests==2.31.0
scikit_learn==1.3.2
torch==2.1.0
Expand Down

0 comments on commit 80f831e

Please sign in to comment.