diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7f8107e..61793ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,18 +1,26 @@ -name: Lint +name: Lint&Tests -on: [push] +on: [ push, pull_request ] jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 - with: - python-version: "3.10" + - name: Set up Python 3.11 + uses: actions/setup-python@v3 + with: + python-version: "3.11" - - name: Check code style with Black - uses: psf/black@stable - + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Check code style with Black + uses: psf/black@stable + + - name: Test with pytest + run: | + python -m pytest diff --git a/requirements.txt b/requirements.txt index 0475296..683659f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,12 @@ einops==0.7.0 -faiss_gpu==1.7.2 +faiss_cpu==1.7.4 fast_pytorch_kmeans==0.2.0.1 geopy==2.4.0 numpy==1.26.1 opencv_python==4.8.1.78 Pillow==10.1.0 prettytable==3.9.0 +pytest==7.2.2 pytorch_lightning==2.1.0 pytorch-metric-learning==2.4.1 Requests==2.31.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/geo_referencers/test_google_maps_referencer.py b/tests/geo_referencers/test_google_maps_referencer.py new file mode 100644 index 0000000..8268474 --- /dev/null +++ b/tests/geo_referencers/test_google_maps_referencer.py @@ -0,0 +1,50 @@ +import aero_vloc as avl + +from aero_vloc.metrics.utils import calculate_distance +from tests.utils import create_localization_pipeline, queries + + +def test_google_maps_referencer_different_zooms(): + """ + Tests GMaps referencer with different zoom levels + """ + for zoom in [1, 1.25, 1.5, 1.75, 2]: + localization_pipeline = create_localization_pipeline( + zoom=zoom, + overlap_level=0.5, + geo_referencer=avl.GoogleMapsReferencer(zoom=17), + ) + number_of_tiles = len(localization_pipeline.retrieval_system.sat_map) + localization_results = localization_pipeline(queries, k_closest=number_of_tiles) + + test_result = localization_results[0] + lat, lon = test_result + uav_image = queries.uav_images[0] + error = calculate_distance( + lat, lon, uav_image.gt_latitude, uav_image.gt_longitude + ) + + assert error < 5 + + +def test_google_maps_referencer_different_overlaps(): + """ + Tests GMaps referencer with different overlap levels + """ + for overlap in [0, 0.25, 0.5, 0.75]: + localization_pipeline = create_localization_pipeline( + zoom=1.5, + overlap_level=overlap, + geo_referencer=avl.GoogleMapsReferencer(zoom=17), + ) + number_of_tiles = len(localization_pipeline.retrieval_system.sat_map) + localization_results = localization_pipeline(queries, k_closest=number_of_tiles) + + test_result = localization_results[0] + lat, lon = test_result + uav_image = queries.uav_images[0] + error = calculate_distance( + lat, lon, uav_image.gt_latitude, uav_image.gt_longitude + ) + + assert error < 5 diff --git a/tests/geo_referencers/test_linear_referencer.py b/tests/geo_referencers/test_linear_referencer.py new file mode 100644 index 0000000..4feb4db --- /dev/null +++ b/tests/geo_referencers/test_linear_referencer.py @@ -0,0 +1,50 @@ +import aero_vloc as avl + +from aero_vloc.metrics.utils import calculate_distance +from tests.utils import create_localization_pipeline, queries + + +def test_linear_referencer_different_zooms(): + """ + Tests linear referencer with different zoom levels + """ + for zoom in [1, 1.25, 1.5, 1.75, 2]: + localization_pipeline = create_localization_pipeline( + zoom=zoom, + overlap_level=0.5, + geo_referencer=avl.LinearReferencer(), + ) + number_of_tiles = len(localization_pipeline.retrieval_system.sat_map) + localization_results = localization_pipeline(queries, k_closest=number_of_tiles) + + test_result = localization_results[0] + lat, lon = test_result + uav_image = queries.uav_images[0] + error = calculate_distance( + lat, lon, uav_image.gt_latitude, uav_image.gt_longitude + ) + + assert error < 5 + + +def test_linear_referencer_different_overlaps(): + """ + Tests linear referencer with different overlap levels + """ + for overlap in [0, 0.25, 0.5, 0.75]: + localization_pipeline = create_localization_pipeline( + zoom=1.5, + overlap_level=overlap, + geo_referencer=avl.LinearReferencer(), + ) + number_of_tiles = len(localization_pipeline.retrieval_system.sat_map) + localization_results = localization_pipeline(queries, k_closest=number_of_tiles) + + test_result = localization_results[0] + lat, lon = test_result + uav_image = queries.uav_images[0] + error = calculate_distance( + lat, lon, uav_image.gt_latitude, uav_image.gt_longitude + ) + + assert error < 5 diff --git a/tests/homography_estimator/test_homography_estimator.py b/tests/homography_estimator/test_homography_estimator.py new file mode 100644 index 0000000..a1ee019 --- /dev/null +++ b/tests/homography_estimator/test_homography_estimator.py @@ -0,0 +1,18 @@ +from tests.utils import create_localization_pipeline, homography_estimator, queries + + +def test_homography_estimator(): + """ + Tests homography estimator with a good example from the test dataset + """ + uav_image = queries.uav_images[0] + retrieval_system = create_localization_pipeline().retrieval_system + predictions, matched_kpts_query, matched_kpts_reference = retrieval_system( + uav_image, vpr_k_closest=2, feature_matcher_k_closest=1 + ) + resize = retrieval_system.feature_matcher.resize + homography_result = homography_estimator( + matched_kpts_query[0], matched_kpts_reference[0], uav_image, resize + ) + + assert homography_result == (259, 363) diff --git a/tests/index_searchers/test_faiss_searcher.py b/tests/index_searchers/test_faiss_searcher.py new file mode 100644 index 0000000..4fcecfc --- /dev/null +++ b/tests/index_searchers/test_faiss_searcher.py @@ -0,0 +1,19 @@ +import aero_vloc as avl +import numpy as np + + +def test_faiss_searcher_k_closest(): + """ + Checks that the index searcher returns the required number of descriptors + """ + descs_shape = 1024 + for number_of_descs in range(50, 1000, 50): + for k_closest in range(1, number_of_descs + 1): + faiss_searcher = avl.FaissSearcher() + descs = np.random.rand(number_of_descs, descs_shape) + faiss_searcher.create(descs) + + query_desc = np.random.rand(1, descs_shape) + result = faiss_searcher.search(query_desc, k_closest=k_closest) + + assert len(result) == k_closest diff --git a/tests/metrics/test_reference_recall.py b/tests/metrics/test_reference_recall.py new file mode 100644 index 0000000..1ecf6ef --- /dev/null +++ b/tests/metrics/test_reference_recall.py @@ -0,0 +1,34 @@ +import aero_vloc as avl +import numpy as np + +from tests.utils import create_localization_pipeline, queries + + +def test_reference_recall(): + """ + Validates the metric using a reasonable level of threshold. + Since one image was taken outside the test map, + the recall value should be equal to 0.5 + """ + localization_pipeline = create_localization_pipeline() + recall, mask = avl.reference_recall( + queries, localization_pipeline, k_closest=2, threshold=10 + ) + + assert np.isclose(recall, 0.5) + assert mask == [True, False] + + +def test_reference_recall_low_threshold(): + """ + Validates the metric using a low level of threshold. + System can't locate queries with such accuracy, + so the recall value should be equal to 0 + """ + localization_pipeline = create_localization_pipeline() + recall, mask = avl.reference_recall( + queries, localization_pipeline, k_closest=2, threshold=1 + ) + + assert np.isclose(recall, 0) + assert mask == [False, False] diff --git a/tests/metrics/test_retrieval_recall.py b/tests/metrics/test_retrieval_recall.py new file mode 100644 index 0000000..d9ab110 --- /dev/null +++ b/tests/metrics/test_retrieval_recall.py @@ -0,0 +1,20 @@ +import aero_vloc as avl +import numpy as np + +from tests.utils import create_localization_pipeline, queries + + +def test_retrieval_recall(): + """ + Validates the metric using a test dataset. + Since one image was taken outside the test map, + the recall value should be equal to 0.5 + """ + localization_pipeline = create_localization_pipeline() + retrieval_system = localization_pipeline.retrieval_system + recall, mask = avl.retrieval_recall( + queries, retrieval_system, vpr_k_closest=2, feature_matcher_k_closest=1 + ) + + assert np.isclose(recall, 0.5) + assert mask == [True, False] diff --git a/tests/test_data/map/0.png b/tests/test_data/map/0.png new file mode 100644 index 0000000..2cff738 Binary files /dev/null and b/tests/test_data/map/0.png differ diff --git a/tests/test_data/map/1.png b/tests/test_data/map/1.png new file mode 100644 index 0000000..927306c Binary files /dev/null and b/tests/test_data/map/1.png differ diff --git a/tests/test_data/map/map_metadata.txt b/tests/test_data/map/map_metadata.txt new file mode 100644 index 0000000..e0547f9 --- /dev/null +++ b/tests/test_data/map/map_metadata.txt @@ -0,0 +1,3 @@ +filename top_left_lat top_left_lon bottom_right_lat bottom_right_lon +0.png 55.5441183336449 38.24398659277347 55.540385063661006 38.25085304785159 +1.png 55.5441183336449 38.25085304785159 55.540385063661006 38.25771950292972 diff --git a/tests/test_data/queries/0.jpg b/tests/test_data/queries/0.jpg new file mode 100644 index 0000000..b159d81 Binary files /dev/null and b/tests/test_data/queries/0.jpg differ diff --git a/tests/test_data/queries/1.png b/tests/test_data/queries/1.png new file mode 100644 index 0000000..1eb9b7b Binary files /dev/null and b/tests/test_data/queries/1.png differ diff --git a/tests/test_data/queries/queries.txt b/tests/test_data/queries/queries.txt new file mode 100644 index 0000000..1fe13e9 --- /dev/null +++ b/tests/test_data/queries/queries.txt @@ -0,0 +1,3 @@ +filename lon lat +0.jpg 38.246245 55.542349 +1.png 46.042279 38.171986 \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..d04674b --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,29 @@ +import aero_vloc as avl + +from pathlib import Path + +salad = avl.SALAD() +light_glue = avl.LightGlue() +homography_estimator = avl.HomographyEstimator() +queries = avl.UAVSeq(Path("tests/test_data/queries/queries.txt")) + + +def create_localization_pipeline( + zoom=1, overlap_level=0, geo_referencer=avl.LinearReferencer() +): + """ + Creates localization pipeline based on SALAD place recognition system, + LightGlue keypoint matcher and test satellite map + """ + sat_map = avl.Map( + Path("tests/test_data/map/map_metadata.txt"), + zoom=zoom, + overlap_level=overlap_level, + geo_referencer=geo_referencer, + ) + faiss_searcher = avl.FaissSearcher() + retrieval_system = avl.RetrievalSystem(salad, sat_map, light_glue, faiss_searcher) + localization_pipeline = avl.LocalizationPipeline( + retrieval_system, homography_estimator + ) + return localization_pipeline