Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support det+cls+rec online prediction pipeline, improve RecResizeNormForInfer to fix bug in single img predict_rec #435

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ per-file-ignores =
tools/infer/text/parallel/base_predict.py:E402
tools/infer/text/parallel/predict_system.py:E402
tools/infer/text/predict_system.py:E402
tools/infer/text/predict_cls.py:E402
tools/infer/text/predict_rec.py:E402
tools/dataset_converters/convert.py:F401,F403
mindocr/data/transforms/transforms_factory.py:F401,F403
Expand Down
8 changes: 5 additions & 3 deletions mindocr/data/transforms/rec_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,14 +411,16 @@ def __call__(self, data):
# tar_h, tar_w = self.targt_shape
resize_h = self.tar_h

max_wh_ratio = self.tar_w / float(self.tar_h)

if not self.keep_ratio:
assert self.tar_w is not None, "Must specify target_width if keep_ratio is False"
resize_w = self.tar_w # if self.tar_w is not None else resized_h * self.max_wh_ratio
else:
src_wh_ratio = w / float(h)
resize_w = math.ceil(min(src_wh_ratio, max_wh_ratio) * resize_h)
if self.tar_w is not None:
max_wh_ratio = self.tar_w / float(self.tar_h)
resize_w = math.ceil(min(src_wh_ratio, max_wh_ratio) * resize_h)
else:
resize_w = math.ceil(src_wh_ratio * resize_h)
resized_img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interpolation)

# TODO: norm before padding
Expand Down
16 changes: 15 additions & 1 deletion tests/st/test_online_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _gen_text_image(texts=TEXTS_2, boxes=BOXES_2, save_fp="gen_img.jpg"):

det_img_fp = _gen_text_image(save_fp="gen_det_input.jpg")
rec_img_fp = _gen_text_image([TEXTS_2[0]], [BOXES_2[0]], "gen_rec_input.jpg")
cls_img_fp = rec_img_fp


def test_det_infer():
Expand All @@ -55,6 +56,17 @@ def test_det_infer():
assert ret == 0, "Det inference fails"


def test_cls_infer():
algo = "MV3"
cmd = (
f"python tools/infer/text/predict_cls.py --image_dir {cls_img_fp} --cls_algorithm {algo} "
f"--draw_img_save_dir ./infer_test"
)
print(f"Running command: \n{cmd}")
ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
assert ret == 0, "Cls inference fails"


def test_rec_infer():
algo = "CRNN"
cmd = (
Expand All @@ -68,10 +80,12 @@ def test_rec_infer():

def test_system_infer():
det_algo = "DB"
cls_algo = "MV3"
rec_algo = "CRNN_CH"
cmd = (
f"python tools/infer/text/predict_system.py --image_dir {det_img_fp} --det_algorithm {det_algo} "
f"--rec_algorithm {rec_algo} --draw_img_save_dir ./infer_test --visualize_output True"
f"--cls_algorithm {cls_algo} --rec_algorithm {rec_algo} "
f"--draw_img_save_dir ./infer_test --visualize_output True"
)
print(f"Running command: \n{cmd}")
ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
Expand Down
34 changes: 34 additions & 0 deletions tools/infer/text/config.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refer to the arg naming in ppocr for low-cost transfer.

https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/tools/infer/utility.py#L114C1-L120C65

Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def create_parser():
# parser.add_argument("--gpu_id", type=int, default=0)

parser.add_argument("--det_model_config", type=str, help="path to det model yaml config") # added
parser.add_argument("--cls_model_config", type=str, help="path to cls model yaml config") # added
parser.add_argument("--rec_model_config", type=str, help="path to rec model yaml config") # added

# params for text detector
Expand Down Expand Up @@ -90,6 +91,39 @@ def create_parser():
parser.add_argument("--use_dilation", type=str2bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")

# params for text direction classification
parser.add_argument(
"--cls_model_dir",
type=str,
help="directory containing the text direction classification model checkpoint best.ckpt, "
"or path to a specific checkpoint file.",
) # determine the network weights
parser.add_argument(
"--cls_batch_mode",
type=str2bool,
default=True,
help="Whether to run text direction classification inference in batch-mode, "
"which is faster but may degrade the accraucy due to padding or resizing to the same shape.",
) # added
parser.add_argument("--cls_batch_num", type=int, default=8)
parser.add_argument("--cls_algorithm", type=str, choices=["MV3"])
parser.add_argument(
"--cls_rotate_thre",
type=float,
default=0.9,
help="Rotate the image when text direction classification score is larger than this threshold.",
)
parser.add_argument(
"--cls_image_shape",
type=str,
default="3, 48, 192",
help="C, H, W for taget image shape. max_wh_ratio=W/H will be used to control the maximum width "
"after 'aspect-ratio-kept' resizing. Set W larger for longer text.",
)
parser.add_argument(
"--cls_label_list", type=str, nargs="+", default=["0", "180"], choices=[["0", "180"], ["0", "90", "180", "270"]]
)

# params for text recognizer
parser.add_argument(
"--rec_algorithm",
Expand Down
9 changes: 6 additions & 3 deletions tools/infer/text/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, task="det", algo="DB", **kwargs):
raise ValueError(f"No postprocess config defined for {algo}. Please check the algorithm name.")
self.rescale_internally = True
self.round = True
elif task == "rec":
elif task in ("rec", "cls"):
# TODO: update character_dict_path and use_space_char after CRNN trained using en_dict.txt released
if algo.startswith("CRNN") or algo.startswith("SVTR"):
# TODO: allow users to input char dict path
Expand All @@ -52,7 +52,10 @@ def __init__(self, task="det", algo="DB", **kwargs):
character_dict_path=dict_path,
use_space_char=False,
)

elif algo.startswith("MV"):
postproc_cfg = dict(
name="ClsPostprocess",
)
else:
raise ValueError(f"No postprocess config defined for {algo}. Please check the algorithm name.")

Expand Down Expand Up @@ -108,6 +111,6 @@ def __call__(self, pred, data=None):
det_res = dict(polys=polys, scores=scores)

return det_res
elif self.task == "rec":
elif self.task in ("rec", "cls"):
output = self.postprocess(pred)
return output
Loading