From a4bdaa7243318a37dea6724b33d8d2b373b53842 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Wed, 18 Sep 2024 10:58:04 -0700 Subject: [PATCH] Setup script and cacheing for downloading test data for benchmarking --- .github/workflows/ci.yml | 12 ++++++++++++ scripts/download_test_data.py | 37 +++++++++++++++++++++++++++++++++++ tests/bench.py | 16 +++++++-------- 3 files changed, 57 insertions(+), 8 deletions(-) create mode 100644 scripts/download_test_data.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 509e122f..5c08b61f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,6 +60,18 @@ jobs: with: fetch-depth: 50 # this is to make sure we obtain the target base commit + - uses: actions/cache@v3 + id: cache + with: + path: downloads + key: ${{ hashFiles('scripts/download_samples.py') }} + + - name: Download Samples + if: steps.cache.outputs.cache-hit != 'true' + run: | + pip install requests + python scripts/download_test_data.py + - name: Set up Python uses: actions/setup-python@v5 with: diff --git a/scripts/download_test_data.py b/scripts/download_test_data.py new file mode 100644 index 00000000..3aba4274 --- /dev/null +++ b/scripts/download_test_data.py @@ -0,0 +1,37 @@ +import os +import urllib.request +import zipfile +from pathlib import Path + +ROOT_DIR = Path(__file__).resolve().parents[1] +DATASETS = [ + "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip", + "http://data.celltrackingchallenge.net/training-datasets/PhC-C2DL-PSC.zip" + "http://data.celltrackingchallenge.net/training-datasets/Fluo-N3DH-CE.zip", +] + + +def download_gt_data(url, root_dir): + data_dir = os.path.join(root_dir, "downloads") + + if not os.path.exists(data_dir): + os.mkdir(data_dir) + + filename = url.split("/")[-1] + file_path = os.path.join(data_dir, filename) + + if not os.path.exists(file_path): + urllib.request.urlretrieve(url, file_path) + + # Unzip the data + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(data_dir) + + +def main(): + for url in DATASETS: + download_gt_data(url, ROOT_DIR) + + +if __name__ == "__main__": + main() diff --git a/tests/bench.py b/tests/bench.py index d3d1cd3f..d266b3f2 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -13,24 +13,26 @@ from traccuracy.matchers import CTCMatcher, IOUMatcher from traccuracy.metrics import CTCMetrics, DivisionMetrics -from tests.test_utils import download_gt_data, gt_data - ROOT_DIR = Path(__file__).resolve().parents[1] TIMEOUT = 20 @pytest.fixture(scope="module") def gt_data_2d(): - url = "http://data.celltrackingchallenge.net/training-datasets/PhC-C2DL-PSC.zip" path = "downloads/Fluo-N2DL-HeLa/01_GT/TRA" - return gt_data(url, ROOT_DIR, path) + return load_ctc_data( + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "man_track.txt"), + ) @pytest.fixture(scope="module") def gt_data_3d(): - url = "http://data.celltrackingchallenge.net/training-datasets/Fluo-N3DH-CE.zip" path = "downloads/Fluo-N3DH-CE/01_GT/TRA" - return gt_data(url, ROOT_DIR, path) + return load_ctc_data( + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "man_track.txt"), + ) @pytest.fixture(scope="module") @@ -74,9 +76,7 @@ def test_load_gt_ctc_data( benchmark, dataset, ): - url = f"http://data.celltrackingchallenge.net/training-datasets/{dataset}.zip" path = f"downloads/{dataset}/01_GT/TRA" - download_gt_data(url, ROOT_DIR) benchmark.pedantic( load_ctc_data,