Skip to content

Commit

Permalink
Import Matched in matchers init and change imports accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Nov 13, 2023
1 parent e16f339 commit 4cf32e0
Show file tree
Hide file tree
Showing 10 changed files with 11 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/traccuracy/matchers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
While we specify ground truth and prediction, it is possible to
write a matching function that matches two arbitrary tracking solutions.
"""
from ._base import Matched
from ._compute_overlap import get_labels_with_overlap
from ._ctc import CTCMatcher
from ._iou import IOUMatcher

__all__ = ["CTCMatcher", "IOUMatcher", "get_labels_with_overlap"]
__all__ = ["CTCMatcher", "IOUMatcher", "get_labels_with_overlap", "Matched"]
2 changes: 1 addition & 1 deletion src/traccuracy/metrics/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from traccuracy.matchers._base import Matched
from traccuracy.matchers import Matched


class Metric(ABC):
Expand Down
2 changes: 1 addition & 1 deletion src/traccuracy/metrics/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ._base import Metric

if TYPE_CHECKING:
from ._base import Matched
from traccuracy.matchers import Matched


class AOGMMetrics(Metric):
Expand Down
2 changes: 1 addition & 1 deletion src/traccuracy/metrics/_track_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ._base import Metric

if TYPE_CHECKING:
from ._base import Matched
from traccuracy.matchers import Matched


def _mapping_to_dict(mapping: List[Tuple[Any, Any]]) -> Dict[Any, List[Any]]:
Expand Down
2 changes: 1 addition & 1 deletion src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from traccuracy import EdgeAttr, NodeAttr

if TYPE_CHECKING:
from traccuracy.matchers._base import Matched
from traccuracy.matchers import Matched

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/traccuracy/track_errors/divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from traccuracy._utils import find_gt_node_matches, find_pred_node_matches

if TYPE_CHECKING:
from traccuracy.matchers._base import Matched
from traccuracy.matchers import Matched

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/test_divisions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from traccuracy import TrackingGraph
from traccuracy.matchers._base import Matched
from traccuracy.matchers import Matched
from traccuracy.metrics._divisions import DivisionMetrics

from tests.test_utils import get_division_graphs
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/test_track_overlap_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import networkx as nx
import pytest
from traccuracy import TrackingGraph
from traccuracy.matchers._base import Matched
from traccuracy.matchers import Matched
from traccuracy.metrics._track_overlap import TrackOverlapMetrics, _mapping_to_dict


Expand Down
2 changes: 1 addition & 1 deletion tests/track_errors/test_ctc_errors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import networkx as nx
import numpy as np
from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph
from traccuracy.matchers._base import Matched
from traccuracy.matchers import Matched
from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors


Expand Down
2 changes: 1 addition & 1 deletion tests/track_errors/test_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pytest
from traccuracy import NodeAttr, TrackingGraph
from traccuracy.matchers._base import Matched
from traccuracy.matchers import Matched
from traccuracy.track_errors.divisions import (
_classify_divisions,
_correct_shifted_divisions,
Expand Down

1 comment on commit 4cf32e0

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE 000ea54 Mean (s) HEAD 4cf32e0 Percent Change
test_load_gt_data 1.2873 2.30911 79.38
test_load_pred_data 1.16527 1.87718 61.09
test_ctc_matched 2.13066 4.28535 101.13
test_ctc_metrics 0.51727 1.00246 93.8
test_ctc_div_metrics 0.28197 0.55798 97.89
test_iou_matched 8.42311 22.0857 162.2
test_iou_div_metrics 0.29306 0.56629 93.23

Please sign in to comment.