Skip to content

Commit

Permalink
Setup script and cacheing for downloading test data for benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Sep 18, 2024
1 parent f837bb2 commit a4bdaa7
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 8 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ jobs:
with:
fetch-depth: 50 # this is to make sure we obtain the target base commit

- uses: actions/cache@v3
id: cache
with:
path: downloads
key: ${{ hashFiles('scripts/download_samples.py') }}

- name: Download Samples
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install requests
python scripts/download_test_data.py
- name: Set up Python
uses: actions/setup-python@v5
with:
Expand Down
37 changes: 37 additions & 0 deletions scripts/download_test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import urllib.request
import zipfile
from pathlib import Path

ROOT_DIR = Path(__file__).resolve().parents[1]
DATASETS = [
"http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip",
"http://data.celltrackingchallenge.net/training-datasets/PhC-C2DL-PSC.zip"
"http://data.celltrackingchallenge.net/training-datasets/Fluo-N3DH-CE.zip",
]


def download_gt_data(url, root_dir):
data_dir = os.path.join(root_dir, "downloads")

if not os.path.exists(data_dir):
os.mkdir(data_dir)

filename = url.split("/")[-1]
file_path = os.path.join(data_dir, filename)

if not os.path.exists(file_path):
urllib.request.urlretrieve(url, file_path)

# Unzip the data
with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_dir)


def main():
for url in DATASETS:
download_gt_data(url, ROOT_DIR)


if __name__ == "__main__":
main()
16 changes: 8 additions & 8 deletions tests/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,26 @@
from traccuracy.matchers import CTCMatcher, IOUMatcher
from traccuracy.metrics import CTCMetrics, DivisionMetrics

from tests.test_utils import download_gt_data, gt_data

ROOT_DIR = Path(__file__).resolve().parents[1]
TIMEOUT = 20


@pytest.fixture(scope="module")
def gt_data_2d():
url = "http://data.celltrackingchallenge.net/training-datasets/PhC-C2DL-PSC.zip"
path = "downloads/Fluo-N2DL-HeLa/01_GT/TRA"
return gt_data(url, ROOT_DIR, path)
return load_ctc_data(
os.path.join(ROOT_DIR, path),
os.path.join(ROOT_DIR, path, "man_track.txt"),
)


@pytest.fixture(scope="module")
def gt_data_3d():
url = "http://data.celltrackingchallenge.net/training-datasets/Fluo-N3DH-CE.zip"
path = "downloads/Fluo-N3DH-CE/01_GT/TRA"
return gt_data(url, ROOT_DIR, path)
return load_ctc_data(
os.path.join(ROOT_DIR, path),
os.path.join(ROOT_DIR, path, "man_track.txt"),
)


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -74,9 +76,7 @@ def test_load_gt_ctc_data(
benchmark,
dataset,
):
url = f"http://data.celltrackingchallenge.net/training-datasets/{dataset}.zip"
path = f"downloads/{dataset}/01_GT/TRA"
download_gt_data(url, ROOT_DIR)

benchmark.pedantic(
load_ctc_data,
Expand Down

0 comments on commit a4bdaa7

Please sign in to comment.