-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #62 from Janelia-Trackathon-2023/benchmark
Add basic set of performance benchmarking tests
- Loading branch information
Showing
4 changed files
with
216 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Do not run this workflow on pull request since this workflow has permission to modify contents.
on:
  push:
    branches:
      - main

permissions:
  # deployments permission to deploy GitHub pages website
  deployments: write
  # contents permission to update benchmark contents in gh-pages branch
  contents: write

jobs:
  benchmark:
    name: Report benchmarks on gh-pages
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
          cache-dependency-path: "pyproject.toml"
          cache: "pip"

      - name: Install dependencies
        run: |
          python -m pip install -U pip
          python -m pip install -e .[test]

      - name: Run benchmark
        run: |
          pytest tests/bench.py --benchmark-json output.json

      - name: Store benchmark results
        uses: benchmark-action/github-action-benchmark@v1
        with:
          name: Python Benchmark with pytest-benchmark
          tool: 'pytest'
          output-file-path: output.json
          github-token: ${{ secrets.GITHUB_TOKEN }}
          auto-push: true
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import copy | ||
import os | ||
import urllib.request | ||
import zipfile | ||
|
||
import pytest | ||
from traccuracy.loaders import load_ctc_data | ||
from traccuracy.matchers import CTCMatched, IOUMatched | ||
from traccuracy.metrics import CTCMetrics, DivisionMetrics | ||
|
||
# Repository root: the parent of the directory containing this test file,
# used to anchor all data paths so tests work from any working directory.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
def download_gt_data():
    """Download and extract the Fluo-N2DL-HeLa ground-truth dataset.

    Fetches the zip archive from the Cell Tracking Challenge server into
    ``<ROOT_DIR>/downloads`` (skipping the download if the archive is
    already present) and extracts it next to the archive.
    """
    # TODO: look into caching this in GitHub Actions instead of re-downloading.
    url = "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip"
    data_dir = os.path.join(ROOT_DIR, "downloads")

    # makedirs(exist_ok=True) is race-safe, unlike check-then-mkdir which can
    # raise FileExistsError when two test processes start at the same time.
    os.makedirs(data_dir, exist_ok=True)

    filename = url.split("/")[-1]
    file_path = os.path.join(data_dir, filename)

    if not os.path.exists(file_path):
        urllib.request.urlretrieve(url, file_path)

    # Unzip the data (idempotent: re-extraction just overwrites files).
    with zipfile.ZipFile(file_path, "r") as zip_ref:
        zip_ref.extractall(data_dir)
|
||
|
||
@pytest.fixture(scope="module")
def gt_data():
    """Module-scoped fixture: download (if needed) and load the GT tracks."""
    download_gt_data()
    gt_dir = os.path.join(ROOT_DIR, "downloads/Fluo-N2DL-HeLa/01_GT/TRA")
    return load_ctc_data(gt_dir, os.path.join(gt_dir, "man_track.txt"))
|
||
|
||
@pytest.fixture(scope="module")
def pred_data():
    """Module-scoped fixture: load the sample predictions shipped with the repo."""
    res_dir = os.path.join(ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES")
    return load_ctc_data(res_dir, os.path.join(res_dir, "res_track.txt"))
|
||
|
||
@pytest.fixture(scope="module")
def ctc_matched(gt_data, pred_data):
    """Module-scoped CTC matching between ground truth and predictions."""
    matched = CTCMatched(gt_data, pred_data)
    return matched
|
||
|
||
@pytest.fixture(scope="module")
def iou_matched(gt_data, pred_data):
    """Module-scoped IOU matching at a permissive 0.1 threshold."""
    matched = IOUMatched(gt_data, pred_data, iou_threshold=0.1)
    return matched
|
||
|
||
def test_load_gt_data(benchmark):
    """Benchmark loading the ground-truth CTC data (one round, one iteration)."""
    download_gt_data()

    # Anchor paths at ROOT_DIR like every other test in this file.
    # Bug fix: the originals were relative to the current working directory,
    # so the benchmark failed whenever pytest was invoked outside the repo root.
    benchmark.pedantic(
        load_ctc_data,
        args=(
            os.path.join(ROOT_DIR, "downloads/Fluo-N2DL-HeLa/01_GT/TRA"),
            os.path.join(ROOT_DIR, "downloads/Fluo-N2DL-HeLa/01_GT/TRA/man_track.txt"),
        ),
        rounds=1,
        iterations=1,
    )
|
||
|
||
def test_load_pred_data(benchmark):
    """Benchmark loading the sample prediction data (one round, one iteration)."""
    res_dir = os.path.join(ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES")
    benchmark.pedantic(
        load_ctc_data,
        args=(res_dir, os.path.join(res_dir, "res_track.txt")),
        rounds=1,
        iterations=1,
    )
|
||
|
||
def test_ctc_matched(benchmark, gt_data, pred_data):
    """Benchmark constructing a CTC matching from the loaded data."""

    def build_matched():
        return CTCMatched(gt_data, pred_data)

    benchmark(build_matched)
|
||
|
||
@pytest.mark.timeout(300)
def test_ctc_metrics(benchmark, ctc_matched):
    """Benchmark CTC metric computation and sanity-check the error counts."""
    # Deep-copy the shared module-scoped fixture so compute() cannot
    # mutate it for the tests that run afterwards.
    ctc_results = benchmark.pedantic(
        lambda: CTCMetrics(copy.deepcopy(ctc_matched)).compute(),
        rounds=1,
        iterations=1,
    )

    expected_counts = {
        "fn_edges": 87,
        "fn_nodes": 39,
        "fp_edges": 60,
        "fp_nodes": 0,
        "ns_nodes": 0,
        "ws_edges": 51,
    }
    for key, count in expected_counts.items():
        assert ctc_results[key] == count
|
||
|
||
def test_ctc_div_metrics(benchmark, ctc_matched):
    """Benchmark division metrics on the CTC matching and check the counts."""

    def run_compute():
        # Deep-copy so the shared module-scoped fixture stays pristine.
        return DivisionMetrics(copy.deepcopy(ctc_matched)).compute()

    div_results = benchmark(run_compute)

    frame0 = div_results["Frame Buffer 0"]
    assert frame0["False Negative Divisions"] == 18
    assert frame0["False Positive Divisions"] == 30
    assert frame0["True Positive Divisions"] == 76
|
||
|
||
def test_iou_matched(benchmark, gt_data, pred_data):
    """Benchmark constructing an IOU matching at a strict 0.5 threshold."""

    def build_matched():
        return IOUMatched(gt_data, pred_data, iou_threshold=0.5)

    benchmark(build_matched)
|
||
|
||
def test_iou_div_metrics(benchmark, iou_matched):
    """Benchmark division metrics on the IOU matching and check the counts."""

    def run_compute():
        # Deep-copy so the shared module-scoped fixture stays pristine.
        return DivisionMetrics(copy.deepcopy(iou_matched)).compute()

    div_results = benchmark(run_compute)

    frame0 = div_results["Frame Buffer 0"]
    assert frame0["False Negative Divisions"] == 25
    assert frame0["False Positive Divisions"] == 31
    assert frame0["True Positive Divisions"] == 69