Skip to content

Commit

Permalink
fix mem issue and speed up whole processing for reco dataset of SynthText and IMGUR5K (mindee#1038)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Sep 1, 2022
1 parent ea19161 commit 75aa42a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 23 deletions.
32 changes: 30 additions & 2 deletions doctr/datasets/imgur5k.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import glob
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

from .datasets import AbstractDataset
Expand Down Expand Up @@ -63,14 +65,26 @@ def __init__(
if not os.path.exists(label_path) or not os.path.exists(img_folder):
raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

self.data: List[Tuple[Union[Path, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[Union[str, Path], Dict[str, Any]]] = []
self.train = train
np_dtype = np.float32

img_names = os.listdir(img_folder)
train_samples = int(len(img_names) * 0.9)
set_slice = slice(train_samples) if self.train else slice(train_samples, None)

# define folder to write IMGUR5K recognition dataset
reco_folder_name = "IMGUR5K_recognition_train" if self.train else "IMGUR5K_recognition_test"
reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name
reco_folder_path = os.path.join(os.path.dirname(self.root), reco_folder_name)
reco_images_counter = 0

if recognition_task and os.path.isdir(reco_folder_path):
self._read_from_folder(reco_folder_path)
return
elif recognition_task and not os.path.isdir(reco_folder_path):
os.makedirs(reco_folder_path, exist_ok=False)

with open(label_path) as f:
annotation_file = json.load(f)

Expand Down Expand Up @@ -110,9 +124,23 @@ def __init__(
img_path=os.path.join(self.root, img_name), geoms=np.asarray(box_targets, dtype=np_dtype)
)
for crop, label in zip(crops, labels):
self.data.append((crop, dict(labels=[label])))
if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
# write data to disk
with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
f.write(label)
tmp_img = Image.fromarray(crop)
tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
reco_images_counter += 1
else:
self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels)))

if recognition_task:
self._read_from_folder(reco_folder_path)

def extra_repr(self) -> str:
    """Return the extra representation string, which reports the ``train`` flag of the dataset."""
    return f"train={self.train}"

def _read_from_folder(self, path: str) -> None:
    """Fill ``self.data`` from a folder of cropped images and their label files.

    Every ``<name>.png`` found directly in ``path`` is paired with ``<name>.txt`` from the same
    folder; the whole content of the text file becomes the single label of the sample.

    Args:
        path: folder containing the ``.png`` crops and the matching ``.txt`` label files

    Raises:
        FileNotFoundError: if a ``.png`` file has no matching ``.txt`` file
    """
    # escape the folder so characters such as "[" or "*" in its name are not read as a pattern
    pattern = os.path.join(glob.escape(path), "*.png")
    # glob yields matches in arbitrary order: sort them to keep the sample order reproducible
    for img_path in sorted(glob.glob(pattern)):
        stem = os.path.splitext(os.path.basename(img_path))[0]
        # NOTE: the default encoding is kept on purpose, the label files are written with it as well
        with open(os.path.join(path, f"{stem}.txt"), "r") as f:
            self.data.append((img_path, dict(labels=[f.read()])))
48 changes: 27 additions & 21 deletions doctr/datasets/synthtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import glob
import os
import pickle
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple

import numpy as np
from PIL import Image
from scipy import io as sio
from tqdm import tqdm

Expand Down Expand Up @@ -56,18 +57,22 @@ def __init__(
**kwargs,
)
self.train = train
self.data: List[Tuple[Union[str, np.ndarray], Dict[str, Any]]] = []
self.data: List[Tuple[str, Dict[str, Any]]] = []
np_dtype = np.float32

# Load mat data
tmp_root = os.path.join(self.root, "SynthText") if self.SHA256 else self.root
pickle_file_name = "SynthText_Reco_train.pkl" if self.train else "SynthText_Reco_test.pkl"
pickle_file_name = "Poly_" + pickle_file_name if use_polygons else pickle_file_name
pickle_path = os.path.join(tmp_root, pickle_file_name)

if recognition_task and os.path.exists(pickle_path):
self._pickle_read(pickle_path)
# define folder to write SynthText recognition dataset
reco_folder_name = "SynthText_recognition_train" if self.train else "SynthText_recognition_test"
reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name
reco_folder_path = os.path.join(tmp_root, reco_folder_name)
reco_images_counter = 0

if recognition_task and os.path.isdir(reco_folder_path):
self._read_from_folder(reco_folder_path)
return
elif recognition_task and not os.path.isdir(reco_folder_path):
os.makedirs(reco_folder_path, exist_ok=False)

mat_data = sio.loadmat(os.path.join(tmp_root, "gt.mat"))
train_samples = int(len(mat_data["imnames"][0]) * 0.9)
Expand Down Expand Up @@ -98,25 +103,26 @@ def __init__(

if recognition_task:
crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path[0]), geoms=word_boxes)
with open(pickle_path, "ab+") as f:
for crop, label in zip(crops, labels):
pickle.dump((crop, label), f)
for crop, label in zip(crops, labels):
if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
# write data to disk
with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
f.write(label)
tmp_img = Image.fromarray(crop)
tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
reco_images_counter += 1
else:
self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=labels)))

if recognition_task:
self._pickle_read(pickle_path)
self._read_from_folder(reco_folder_path)

self.root = tmp_root

def extra_repr(self) -> str:
    """Return the extra representation string, which reports the ``train`` flag of the dataset."""
    return f"train={self.train}"

def _pickle_read(self, path: str) -> None:
    """Append every pickled ``(crop, label)`` pair stored in a file to ``self.data``.

    Args:
        path: file holding consecutively pickled ``(crop, label)`` tuples
    """
    # NOTE: unpickling can run arbitrary code, only read files that this dataset produced itself
    with open(path, "rb") as stream:
        try:
            while True:
                crop, label = pickle.load(stream)
                self.data.append((crop, dict(labels=[label])))
        except EOFError:
            # running off the end of the stream is the regular way out of the loop
            pass
def _read_from_folder(self, path: str) -> None:
    """Fill ``self.data`` from a folder of cropped images and their label files.

    Every ``<name>.png`` found directly in ``path`` is paired with ``<name>.txt`` from the same
    folder; the whole content of the text file becomes the single label of the sample.

    Args:
        path: folder containing the ``.png`` crops and the matching ``.txt`` label files

    Raises:
        FileNotFoundError: if a ``.png`` file has no matching ``.txt`` file
    """
    # escape the folder so characters such as "[" or "*" in its name are not read as a pattern
    pattern = os.path.join(glob.escape(path), "*.png")
    # glob yields matches in arbitrary order: sort them to keep the sample order reproducible
    for img_path in sorted(glob.glob(pattern)):
        stem = os.path.splitext(os.path.basename(img_path))[0]
        # NOTE: the default encoding is kept on purpose, the label files are written with it as well
        with open(os.path.join(path, f"{stem}.txt"), "r") as f:
            self.data.append((img_path, dict(labels=[f.read()])))

0 comments on commit 75aa42a

Please sign in to comment.