Skip to content

Commit

Permalink
[Fix] unify recognition dataset parts return signature (mindee#1041)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Sep 2, 2022
1 parent 75aa42a commit f9d3d78
Show file tree
Hide file tree
Showing 18 changed files with 70 additions and 45 deletions.
4 changes: 2 additions & 2 deletions doctr/datasets/cord.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(

# List images
tmp_root = os.path.join(self.root, "image")
self.data: List[Tuple[Union[str, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
self.train = train
np_dtype = np.float32
for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
)
for crop, label in zip(crops, list(text_targets)):
self.data.append((crop, dict(labels=[label])))
self.data.append((crop, label))
else:
self.data.append(
(img_path, dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)))
Expand Down
10 changes: 10 additions & 0 deletions doctr/datasets/datasets/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
class AbstractDataset(_AbstractDataset):
def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
img_name, target = self.data[index]

# Check target
if isinstance(target, dict):
assert "boxes" in target, "Target should contain 'boxes' key"
assert "labels" in target, "Target should contain 'labels' key"
else:
assert isinstance(target, str) or isinstance(
target, np.ndarray
), "Target should be a string or a numpy array"

# Read image
img = (
tensor_from_numpy(img_name, dtype=torch.float32)
Expand Down
10 changes: 10 additions & 0 deletions doctr/datasets/datasets/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
class AbstractDataset(_AbstractDataset):
def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
img_name, target = self.data[index]

# Check target
if isinstance(target, dict):
assert "boxes" in target, "Target should contain 'boxes' key"
assert "labels" in target, "Target should contain 'labels' key"
else:
assert isinstance(target, str) or isinstance(
target, np.ndarray
), "Target should be a string or a numpy array"

# Read image
img = (
tensor_from_numpy(img_name, dtype=tf.float32)
Expand Down
4 changes: 2 additions & 2 deletions doctr/datasets/funsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(

# # List images
tmp_root = os.path.join(self.root, subfolder, "images")
self.data: List[Tuple[Union[str, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
# File existence check
if not os.path.exists(os.path.join(tmp_root, img_path)):
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
for crop, label in zip(crops, list(text_targets)):
# filter labels with unknown characters
if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]):
self.data.append((crop, dict(labels=[label])))
self.data.append((crop, label))
else:
self.data.append(
(
Expand Down
4 changes: 2 additions & 2 deletions doctr/datasets/ic03.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
**kwargs,
)
self.train = train
self.data: List[Tuple[Union[str, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
np_dtype = np.float32

# Load xml data
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(
crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
for crop, label in zip(crops, labels):
if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
self.data.append((crop, dict(labels=[label])))
self.data.append((crop, label))
else:
self.data.append((name.text, dict(boxes=boxes, labels=labels)))

Expand Down
4 changes: 2 additions & 2 deletions doctr/datasets/ic13.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}"
)

self.data: List[Tuple[Union[Path, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
np_dtype = np.float32

img_names = os.listdir(img_folder)
Expand Down Expand Up @@ -94,6 +94,6 @@ def __init__(
if recognition_task:
crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets)
for crop, label in zip(crops, labels):
self.data.append((crop, dict(labels=[label])))
self.data.append((crop, label))
else:
self.data.append((img_path, dict(boxes=box_targets, labels=labels)))
6 changes: 3 additions & 3 deletions doctr/datasets/iiit5k.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import os
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import scipy.io as sio
Expand Down Expand Up @@ -62,7 +62,7 @@ def __init__(
mat_file = "trainCharBound" if self.train else "testCharBound"
mat_data = sio.loadmat(os.path.join(tmp_root, f"{mat_file}.mat"))[mat_file][0]

self.data: List[Tuple[str, Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
np_dtype = np.float32

for img_path, label, box_targets in tqdm(iterable=mat_data, desc="Unpacking IIIT5K", total=len(mat_data)):
Expand All @@ -74,7 +74,7 @@ def __init__(
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}")

if recognition_task:
self.data.append((_raw_path, dict(labels=[_raw_label])))
self.data.append((_raw_path, _raw_label))
else:
if use_polygons:
# (x, y) coordinates of top left, top right, bottom right, bottom left corners
Expand Down
4 changes: 2 additions & 2 deletions doctr/datasets/imgur5k.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
if not os.path.exists(label_path) or not os.path.exists(img_folder):
raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

self.data: List[Tuple[Union[str, Path], Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
self.train = train
np_dtype = np.float32

Expand Down Expand Up @@ -143,4 +143,4 @@ def extra_repr(self) -> str:
def _read_from_folder(self, path: str) -> None:
for img_path in glob.glob(os.path.join(path, "*.png")):
with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
self.data.append((img_path, dict(labels=[f.read()])))
self.data.append((img_path, f.read()))
8 changes: 4 additions & 4 deletions doctr/datasets/mjsynth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import os
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple

from tqdm import tqdm

Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(
if not os.path.exists(label_path) or not os.path.exists(img_folder):
raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

self.data: List[Tuple[str, Dict[str, Any]]] = []
self.data: List[Tuple[str, str]] = []
self.train = train

with open(label_path) as f:
Expand All @@ -96,10 +96,10 @@ def __init__(

for path in tqdm(iterable=img_paths[set_slice], desc="Unpacking MJSynth", total=len(img_paths[set_slice])):
if path not in self.BLACKLIST:
label = [path.split("_")[1]]
label = path.split("_")[1]
img_path = os.path.join(img_folder, path[2:]).strip()

self.data.append((img_path, dict(labels=label)))
self.data.append((img_path, label))

def extra_repr(self) -> str:
return f"train={self.train}"
4 changes: 2 additions & 2 deletions doctr/datasets/sroie.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self.train = train

tmp_root = os.path.join(self.root, "images")
self.data: List[Tuple[Union[str, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
np_dtype = np.float32

for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking SROIE", total=len(os.listdir(tmp_root))):
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path), geoms=coords)
for crop, label in zip(crops, labels):
if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
self.data.append((crop, dict(labels=[label])))
self.data.append((crop, label))
else:
self.data.append((img_path, dict(boxes=coords, labels=labels)))

Expand Down
4 changes: 2 additions & 2 deletions doctr/datasets/svhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
**kwargs,
)
self.train = train
self.data: List[Tuple[Union[str, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
np_dtype = np.float32

tmp_root = os.path.join(self.root, "train" if train else "test")
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets)
for crop, label in zip(crops, label_targets):
if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
self.data.append((crop, dict(labels=[label])))
self.data.append((crop, label))
else:
self.data.append((img_name, dict(boxes=box_targets, labels=label_targets)))

Expand Down
4 changes: 2 additions & 2 deletions doctr/datasets/svt.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
**kwargs,
)
self.train = train
self.data: List[Tuple[Union[str, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
np_dtype = np.float32

# Load xml data
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(
crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
for crop, label in zip(crops, labels):
if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
self.data.append((crop, dict(labels=[label])))
self.data.append((crop, label))
else:
self.data.append((name.text, dict(boxes=boxes, labels=labels)))

Expand Down
6 changes: 3 additions & 3 deletions doctr/datasets/synthtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import glob
import os
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
**kwargs,
)
self.train = train
self.data: List[Tuple[str, Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
np_dtype = np.float32

# Load mat data
Expand Down Expand Up @@ -125,4 +125,4 @@ def extra_repr(self) -> str:
def _read_from_folder(self, path: str) -> None:
for img_path in glob.glob(os.path.join(path, "*.png")):
with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
self.data.append((img_path, dict(labels=[f.read()])))
self.data.append((img_path, f.read()))
1 change: 0 additions & 1 deletion references/recognition/evaluate_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
val_loss, batch_cnt = 0, 0
for images, targets in tqdm(val_loader):
try:
targets = [t["labels"][0] for t in targets]
if torch.cuda.is_available():
images = images.cuda()
images = batch_transforms(images)
Expand Down
1 change: 0 additions & 1 deletion references/recognition/evaluate_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def evaluate(model, val_loader, batch_transforms, val_metric):
for images, targets in tqdm(val_iter):
try:
images = batch_transforms(images)
targets = [t["labels"][0] for t in targets]
out = model(images, targets, return_preds=True, training=False)
# Compute metric
if len(out["preds"]):
Expand Down
21 changes: 16 additions & 5 deletions tests/common/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,23 @@ def test_abstractdataset(mock_image_path):
# Check transforms
path = Path(mock_image_path)
ds = datasets.datasets.AbstractDataset(path.parent)
# Check target format
with pytest.raises(AssertionError):
ds.data = [(path.name, 0)]
img, target = ds[0]
with pytest.raises(AssertionError):
ds.data = [(path.name, dict(boxes=np.array([[0, 0, 1, 1]])))]
img, target = ds[0]
with pytest.raises(AssertionError):
ds.data = [(ds.data[0][0], {"label": "A"})]
img, target = ds[0]

# Patch some data
ds.data = [(path.name, 0)]
ds.data = [(path.name, np.array([0]))]

# Fetch the img
img, target = ds[0]
assert isinstance(target, int) and target == 0
assert isinstance(target, np.ndarray) and target == np.array([0])

# Check img_transforms
ds.img_transforms = lambda x: 1 - x
Expand All @@ -44,13 +55,13 @@ def test_abstractdataset(mock_image_path):
assert np.all(img3.numpy() == img.numpy()) and (target3 == (target + 1))

# Check inplace modifications
ds.data = [(ds.data[0][0], {"label": "A"})]
ds.data = [(ds.data[0][0], "A")]

def inplace_transfo(x, target):
target["label"] += "B"
target += "B"
return x, target

ds.sample_transforms = inplace_transfo
_, t = ds[0]
_, t = ds[0]
assert t["label"] == "AB"
assert t == "AB"
10 changes: 4 additions & 6 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,11 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly
def _validate_dataset_recognition_part(ds, input_size, batch_size=2):

# Fetch one sample
img, target = ds[0]
img, label = ds[0]
assert isinstance(img, torch.Tensor)
assert img.shape == (3, *input_size)
assert img.dtype == torch.float32
assert isinstance(target, dict)
assert len(target["labels"]) == 1
assert isinstance(target["labels"][0], str)
assert isinstance(label, str)

# Check batching
loader = DataLoader(
Expand All @@ -68,9 +66,9 @@ def _validate_dataset_recognition_part(ds, input_size, batch_size=2):
collate_fn=ds.collate_fn,
)

images, targets = next(iter(loader))
images, labels = next(iter(loader))
assert isinstance(images, torch.Tensor) and images.shape == (batch_size, 3, *input_size)
assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)
assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels)


def test_visiondataset():
Expand Down
10 changes: 4 additions & 6 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,18 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly
def _validate_dataset_recognition_part(ds, input_size, batch_size=2):

# Fetch one sample
img, target = ds[0]
img, label = ds[0]
assert isinstance(img, tf.Tensor)
assert img.shape == (*input_size, 3)
assert img.dtype == tf.float32
assert isinstance(target, dict)
assert len(target["labels"]) == 1
assert isinstance(target["labels"][0], str)
assert isinstance(label, str)

# Check batching
loader = DataLoader(ds, batch_size=batch_size)

images, targets = next(iter(loader))
images, labels = next(iter(loader))
assert isinstance(images, tf.Tensor) and images.shape == (batch_size, *input_size, 3)
assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)
assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels)


def test_detection_dataset(mock_image_folder, mock_detection_label):
Expand Down

0 comments on commit f9d3d78

Please sign in to comment.