-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval_e2e.py
70 lines (57 loc) · 2.16 KB
/
eval_e2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import editdistance
import numpy as np
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm import tqdm
from deepscribe2.datasets import PFADetectionDataModule
from deepscribe2.pipeline import DeepScribePipeline
# Directory holding the pretrained checkpoints. Run download_artifacts.sh first.
ARTIFACTS_DIR = "artifacts"
# Replace this with wherever you put the data. download_data.sh has an example.
DATA_BASE = "data/DeepScribe_Data_2023-02-04_public"

# Build the detection datamodule and set it up for the test stage.
pfa_datamodule = PFADetectionDataModule(
    DATA_BASE,
    batch_size=10,
    start_from_one=True,
    localization_only=False,
)
pfa_datamodule.prepare_data()
pfa_datamodule.setup(stage="test")

# Checkpoint handed to the pipeline as its first positional argument.
detector_ckpt = os.path.join(ARTIFACTS_DIR, "trained_detector_public_multiclass.ckpt")
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The pipeline can also be initialized from model objects directly.
# Alternative arguments kept from the original script (presumably a separate
# detector + classifier setup -- confirm before use):
#   os.path.join(ARTIFACTS_DIR, "trained_detector_public.ckpt"),
#   classifier_ckpt=os.path.join(ARTIFACTS_DIR, "trained_classifier_public.ckpt"),
pipeline = DeepScribePipeline.from_checkpoints(
    detector_ckpt,
    score_thresh=0.5,
    device=device,
)
# Accumulators for the run:
#   map_metric - mean average precision over the predictions and targets
#   edit_dists - one length-normalized edit distance per scored image
#   failed     - predictions whose "ordering" entry is missing or None
map_metric = MeanAveragePrecision()
edit_dists = []
failed = 0

with torch.no_grad():
    for imgs, targets in tqdm(pfa_datamodule.test_dataloader()):
        preds = pipeline(imgs)
        # Move every tensor in each prediction to the CPU; non-tensor entries
        # are passed through unchanged.
        preds = [
            {
                key: entry.cpu() if isinstance(entry, torch.Tensor) else entry
                for key, entry in pred.items()
            }
            for pred in preds
        ]
        map_metric.update(preds, targets)

        # compute edit dist
        for pred, targ in zip(preds, targets):
            ordering = pred.get("ordering")
            if ordering is None:
                # No ordering available for this prediction; count it and move on.
                failed += 1
                continue

            # Predicted labels rearranged by the predicted ordering.
            ordered_labels = pred["labels"][ordering].tolist()
            targ_labels = targ["labels"].tolist()
            # Normalize by the target length. The denominator is clamped to 1
            # so a target with no labels yields the raw distance instead of
            # raising ZeroDivisionError.
            edit_dist = editdistance.eval(
                ordered_labels,
                targ_labels,
            ) / max(len(targ_labels), 1)
            edit_dists.append(edit_dist)

# Summary is printed as: median / mean (std).
if edit_dists:
    print(
        f"edit dists: {np.median(edit_dists)} / {np.mean(edit_dists)} ({np.std(edit_dists)})"
    )
else:
    # numpy would warn and report nan for an empty list; say so explicitly.
    print("edit dists: n/a (no prediction had an ordering)")
print(f"map: {map_metric.compute()}")
print(f"failed: {failed}")