Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests #8

Merged
merged 3 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
name: Lint
name: Lint&Tests

on: [push]
on: [ push, pull_request ]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v3

- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: "3.10"
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: "3.11"

- name: Check code style with Black
uses: psf/black@stable

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt

- name: Check code style with Black
uses: psf/black@stable

- name: Test with pytest
run: |
python -m pytest
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
einops==0.7.0
faiss_gpu==1.7.2
faiss_cpu==1.7.4
fast_pytorch_kmeans==0.2.0.1
geopy==2.4.0
numpy==1.26.1
opencv_python==4.8.1.78
Pillow==10.1.0
prettytable==3.9.0
pytest==7.2.2
pytorch_lightning==2.1.0
pytorch-metric-learning==2.4.1
Requests==2.31.0
Expand Down
Empty file added tests/__init__.py
Empty file.
50 changes: 50 additions & 0 deletions tests/geo_referencers/test_google_maps_referencer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import aero_vloc as avl

from aero_vloc.metrics.utils import calculate_distance
from tests.utils import create_localization_pipeline, queries


def test_google_maps_referencer_different_zooms():
"""
Tests GMaps referencer with different zoom levels
"""
for zoom in [1, 1.25, 1.5, 1.75, 2]:
localization_pipeline = create_localization_pipeline(
zoom=zoom,
overlap_level=0.5,
geo_referencer=avl.GoogleMapsReferencer(zoom=17),
)
number_of_tiles = len(localization_pipeline.retrieval_system.sat_map)
localization_results = localization_pipeline(queries, k_closest=number_of_tiles)

test_result = localization_results[0]
lat, lon = test_result
uav_image = queries.uav_images[0]
error = calculate_distance(
lat, lon, uav_image.gt_latitude, uav_image.gt_longitude
)

assert error < 5


def test_google_maps_referencer_different_overlaps():
"""
Tests GMaps referencer with different overlap levels
"""
for overlap in [0, 0.25, 0.5, 0.75]:
localization_pipeline = create_localization_pipeline(
zoom=1.5,
overlap_level=overlap,
geo_referencer=avl.GoogleMapsReferencer(zoom=17),
)
number_of_tiles = len(localization_pipeline.retrieval_system.sat_map)
localization_results = localization_pipeline(queries, k_closest=number_of_tiles)

test_result = localization_results[0]
lat, lon = test_result
uav_image = queries.uav_images[0]
error = calculate_distance(
lat, lon, uav_image.gt_latitude, uav_image.gt_longitude
)

assert error < 5
50 changes: 50 additions & 0 deletions tests/geo_referencers/test_linear_referencer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import aero_vloc as avl

from aero_vloc.metrics.utils import calculate_distance
from tests.utils import create_localization_pipeline, queries


def test_linear_referencer_different_zooms():
"""
Tests linear referencer with different zoom levels
"""
for zoom in [1, 1.25, 1.5, 1.75, 2]:
localization_pipeline = create_localization_pipeline(
zoom=zoom,
overlap_level=0.5,
geo_referencer=avl.LinearReferencer(),
)
number_of_tiles = len(localization_pipeline.retrieval_system.sat_map)
localization_results = localization_pipeline(queries, k_closest=number_of_tiles)

test_result = localization_results[0]
lat, lon = test_result
uav_image = queries.uav_images[0]
error = calculate_distance(
lat, lon, uav_image.gt_latitude, uav_image.gt_longitude
)

assert error < 5


def test_linear_referencer_different_overlaps():
"""
Tests linear referencer with different overlap levels
"""
for overlap in [0, 0.25, 0.5, 0.75]:
localization_pipeline = create_localization_pipeline(
zoom=1.5,
overlap_level=overlap,
geo_referencer=avl.LinearReferencer(),
)
number_of_tiles = len(localization_pipeline.retrieval_system.sat_map)
localization_results = localization_pipeline(queries, k_closest=number_of_tiles)

test_result = localization_results[0]
lat, lon = test_result
uav_image = queries.uav_images[0]
error = calculate_distance(
lat, lon, uav_image.gt_latitude, uav_image.gt_longitude
)

assert error < 5
18 changes: 18 additions & 0 deletions tests/homography_estimator/test_homography_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from tests.utils import create_localization_pipeline, homography_estimator, queries


def test_homography_estimator():
"""
Tests homography estimator with a good example from the test dataset
"""
uav_image = queries.uav_images[0]
retrieval_system = create_localization_pipeline().retrieval_system
predictions, matched_kpts_query, matched_kpts_reference = retrieval_system(
uav_image, vpr_k_closest=2, feature_matcher_k_closest=1
)
resize = retrieval_system.feature_matcher.resize
homography_result = homography_estimator(
matched_kpts_query[0], matched_kpts_reference[0], uav_image, resize
)

assert homography_result == (259, 363)
19 changes: 19 additions & 0 deletions tests/index_searchers/test_faiss_searcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import aero_vloc as avl
import numpy as np


def test_faiss_searcher_k_closest():
"""
Checks that the index searcher returns the required number of descriptors
"""
descs_shape = 1024
for number_of_descs in range(50, 1000, 50):
for k_closest in range(1, number_of_descs + 1):
faiss_searcher = avl.FaissSearcher()
descs = np.random.rand(number_of_descs, descs_shape)
faiss_searcher.create(descs)

query_desc = np.random.rand(1, descs_shape)
result = faiss_searcher.search(query_desc, k_closest=k_closest)

assert len(result) == k_closest
34 changes: 34 additions & 0 deletions tests/metrics/test_reference_recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import aero_vloc as avl
import numpy as np

from tests.utils import create_localization_pipeline, queries


def test_reference_recall():
"""
Validates the metric using a reasonable level of threshold.
Since one image was taken outside the test map,
the recall value should be equal to 0.5
"""
localization_pipeline = create_localization_pipeline()
recall, mask = avl.reference_recall(
queries, localization_pipeline, k_closest=2, threshold=10
)

assert np.isclose(recall, 0.5)
assert mask == [True, False]


def test_reference_recall_low_threshold():
"""
Validates the metric using a low level of threshold.
System can't locate queries with such accuracy,
so the recall value should be equal to 0
"""
localization_pipeline = create_localization_pipeline()
recall, mask = avl.reference_recall(
queries, localization_pipeline, k_closest=2, threshold=1
)

assert np.isclose(recall, 0)
assert mask == [False, False]
20 changes: 20 additions & 0 deletions tests/metrics/test_retrieval_recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import aero_vloc as avl
import numpy as np

from tests.utils import create_localization_pipeline, queries


def test_retrieval_recall():
"""
Validates the metric using a test dataset.
Since one image was taken outside the test map,
the recall value should be equal to 0.5
"""
localization_pipeline = create_localization_pipeline()
retrieval_system = localization_pipeline.retrieval_system
recall, mask = avl.retrieval_recall(
queries, retrieval_system, vpr_k_closest=2, feature_matcher_k_closest=1
)

assert np.isclose(recall, 0.5)
assert mask == [True, False]
Binary file added tests/test_data/map/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/test_data/map/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions tests/test_data/map/map_metadata.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
filename top_left_lat top_left_lon bottom_right_lat bottom_right_lon
0.png 55.5441183336449 38.24398659277347 55.540385063661006 38.25085304785159
1.png 55.5441183336449 38.25085304785159 55.540385063661006 38.25771950292972
Binary file added tests/test_data/queries/0.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/test_data/queries/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions tests/test_data/queries/queries.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
filename lon lat
0.jpg 38.246245 55.542349
1.png 46.042279 38.171986
29 changes: 29 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import aero_vloc as avl

from pathlib import Path

salad = avl.SALAD()
light_glue = avl.LightGlue()
homography_estimator = avl.HomographyEstimator()
queries = avl.UAVSeq(Path("tests/test_data/queries/queries.txt"))


def create_localization_pipeline(
zoom=1, overlap_level=0, geo_referencer=avl.LinearReferencer()
):
"""
Creates localization pipeline based on SALAD place recognition system,
LightGlue keypoint matcher and test satellite map
"""
sat_map = avl.Map(
Path("tests/test_data/map/map_metadata.txt"),
zoom=zoom,
overlap_level=overlap_level,
geo_referencer=geo_referencer,
)
faiss_searcher = avl.FaissSearcher()
retrieval_system = avl.RetrievalSystem(salad, sat_map, light_glue, faiss_searcher)
localization_pipeline = avl.LocalizationPipeline(
retrieval_system, homography_estimator
)
return localization_pipeline
Loading