diff --git a/configs/cls/mobilenetv3/cls_mv3.yaml b/configs/cls/mobilenetv3/cls_mv3.yaml
index 2d5d036f0..e38fbb691 100644
--- a/configs/cls/mobilenetv3/cls_mv3.yaml
+++ b/configs/cls/mobilenetv3/cls_mv3.yaml
@@ -142,3 +142,48 @@ eval:
     drop_remainder: False
     max_rowsize: 12
     num_workers: 8
+
+predict:
+  device_target: Ascend
+  device_id: 0
+  max_device_memory: 8GB
+  amp_level: O2
+  mode: 0
+  ckpt_load_path: /root/.mindspore/models/cls_mobilenetv3-92db9c58.ckpt
+  dataset_sink_mode: False
+  dataset:
+    type: RecDataset
+    dataset_root: dir/to/dataset
+    data_dir: all_images
+    label_file: val_cls_gt.txt
+    sample_ratio: 1.0
+    shuffle: False
+    transform_pipeline:
+      - DecodeImage:
+          img_mode: BGR
+          to_float32: False
+      - Rotate90IfVertical:
+          threshold: 2.0
+          direction: counterclockwise
+      - ClsLabelEncode:
+          label_list: *label_list
+      - RecResizeImg:
+          image_shape: [48, 192] # H, W
+          padding: False # aspect ratio will be preserved if true.
+      - NormalizeImage:
+          bgr_to_rgb: True
+          is_hwc: True
+          mean: [127.0, 127.0, 127.0]
+          std: [127.0, 127.0, 127.0]
+      - ToCHWImage:
+    # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
+    output_columns: ['image', 'label'] # TODO: return text string padding w/ fixed length, and a scalar to indicate the length
+    net_input_column_index: [0] # input indices for network forward func in output_columns
+    label_column_index: [1] # input indices marked as label
+
+  loader:
+    shuffle: False
+    batch_size: 8
+    drop_remainder: False
+    max_rowsize: 12
+    num_workers: 8
\ No newline at end of file
diff --git a/configs/det/dbnet/db_r50_icdar15.yaml b/configs/det/dbnet/db_r50_icdar15.yaml
index f6eb56b3b..135083893 100644
--- a/configs/det/dbnet/db_r50_icdar15.yaml
+++ b/configs/det/dbnet/db_r50_icdar15.yaml
@@ -157,7 +157,12 @@ eval:
     num_workers: 2
 
 predict:
-  ckpt_load_path: tmp_det/best.ckpt
+  device_target: Ascend
+  device_id: 0
+  max_device_memory: 8GB
+  amp_level: O2
+  mode: 0
+  ckpt_load_path: /root/.mindspore/models/dbnet_resnet50-c3a4aa24.ckpt
   output_save_dir: ./output
   dataset_sink_mode: False
   dataset:
diff --git a/configs/rec/crnn/crnn_resnet34.yaml b/configs/rec/crnn/crnn_resnet34.yaml
index 893e481f5..68c882ec3 100644
--- a/configs/rec/crnn/crnn_resnet34.yaml
+++ b/configs/rec/crnn/crnn_resnet34.yaml
@@ -150,7 +150,12 @@ eval:
     num_workers: 8
 
 predict:
-  ckpt_load_path: ./tmp_rec/best.ckpt
+  device_target: Ascend
+  device_id: 0
+  max_device_memory: 8GB
+  amp_level: O2
+  mode: 0
+  ckpt_load_path: /root/.mindspore/models/crnn_resnet34-83f37f07.ckpt
   vis_font_path: tools/utils/simfang.ttf
   dataset_sink_mode: False
   dataset:
diff --git a/mindocr/infer/__init__.py b/mindocr/infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindocr/infer/classification/__init__.py b/mindocr/infer/classification/__init__.py
new file mode 100644
index 000000000..cb61c6ab6
--- /dev/null
+++ b/mindocr/infer/classification/__init__.py
@@ -0,0 +1,3 @@
+from .cls_infer_node import ClsInferNode
+from .cls_post_node import ClsPostNode
+from .cls_pre_node import ClsPreNode
diff --git a/mindocr/infer/classification/classification.py b/mindocr/infer/classification/classification.py
new file mode 100644
index 000000000..48a7358ff
--- /dev/null
+++ b/mindocr/infer/classification/classification.py
@@ -0,0 +1,63 @@
+import logging
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+from typing import List
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from tools.infer.text.utils import get_ckpt_file
+from mindocr.data.transforms import create_transforms, run_transforms
+from mindocr.postprocess import build_postprocess
+from mindocr.infer.utils.model import MSModel, LiteModel
+
+
+algo_to_model_name = {
+    "MV3": "cls_mobilenet_v3_small_100_model",
+}
+logger = logging.getLogger("mindocr")
+
+class ClsPreprocess(object):
+    def __init__(self, args):
+        self.args = args
+        with open(args.cls_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline)
+
+    def __call__(self, img):
+        data = {"image": img}
+        # skip the first transform (DecodeImage): the input image is already a decoded array
+        data = run_transforms(data, self.transforms[1:])
+        return data
+
+
+class ClsModelMS(MSModel):
+    def __init__(self, args):
+        self.args = args
+        self.model_name = algo_to_model_name[args.cls_algorithm]
+        self.config_path = args.cls_config_path
+        self._init_model(self.model_name, self.config_path)
+
+
+class ClsModelLite(LiteModel):
+    def __init__(self, args):
+        self.args = args
+        self.model_name = algo_to_model_name[args.cls_algorithm]
+        self.config_path = args.cls_config_path
+        self._init_model(self.model_name, self.config_path)
+
+INFER_CLS_MAP = {"MindSporeLite": ClsModelLite, "MindSpore": ClsModelMS}
+
+class ClsPostProcess(object):
+    def __init__(self, args):
+        self.args = args
+        with open(args.cls_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.postprocessor = build_postprocess(self.yaml_cfg.postprocess)
+
+    def __call__(self, pred):
+        return self.postprocessor(pred)
\ No newline at end of file
diff --git a/mindocr/infer/classification/cls_infer_node.py b/mindocr/infer/classification/cls_infer_node.py
new file mode 100644
index 000000000..7218561f1
--- /dev/null
+++ b/mindocr/infer/classification/cls_infer_node.py
@@ -0,0 +1,57 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from infer.classification.classification import INFER_CLS_MAP
+
+
+class ClsInferNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(ClsInferNode, self).__init__(args, msg_queue, tqdm_info)
+        self.args = args
+        self.cls_model = None
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        with open(self.args.cls_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.batch_size = self.yaml_cfg.predict.loader.batch_size
+        ClsModel = INFER_CLS_MAP[self.yaml_cfg.predict.backend]
+        self.cls_model = ClsModel(self.args)
+        super().init_self_args()
+
+    def process(self, input_data):
+        """
+        Input:
+          - input_data.data: [np.ndarray], shape:[3,h,w], e.g. [3,48,192]
+        Output:
+          - input_data.data: [np.ndarray], shape:[?,2]
+        """
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data
+        data = [np.expand_dims(d, 0) for d in data if len(d.shape) == 3]
+        data = np.concatenate(data, axis=0)
+
+        preds = []
+        for batch_i in range(data.shape[0] // self.batch_size + 1):
+            start_i = batch_i * self.batch_size
+            end_i = (batch_i + 1) * self.batch_size
+            d = data[start_i:end_i]
+            if d.shape[0] == 0:
+                continue
+            pred = self.cls_model([d])
+            preds.append(pred[0])
+        preds = np.concatenate(preds, axis=0)
+        input_data.data = {"pred": preds}
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/classification/cls_post_node.py b/mindocr/infer/classification/cls_post_node.py
new file mode 100644
index 000000000..1c3b140c7
--- /dev/null
+++ b/mindocr/infer/classification/cls_post_node.py
@@ -0,0 +1,63 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+import cv2
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from infer.classification.classification import ClsPostProcess
+from tools.infer.text.utils import crop_text_region
+from pipeline.data_process.utils.cv_utils import crop_box_from_image
+
+
+class ClsPostNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(ClsPostNode, self).__init__(args, msg_queue, tqdm_info)
+        self.cls_postprocess = ClsPostProcess(args)
+        self.task_type = self.args.task_type
+        self.cls_thresh = 0.9
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def process(self, input_data):
+        """
+        Input:
+          - input_data.data: [np.ndarray], shape:[?,2]
+        Output:
+          - input_data.sub_image_list: [np.ndarray], shape:[1,3,-1,-1], e.g. [1,3,48,192]
+          - input_data.data = None
+          or
+          - input_data.infer_result = [(str, float)]
+        """
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data
+        pred = data["pred"]
+        output = self.cls_postprocess(pred)
+        angles = output["angles"]
+        scores = np.array(output["scores"]).tolist()
+
+        batch = input_data.sub_image_size
+        if self.task_type.value == TaskType.DET_CLS_REC.value:
+            sub_images = input_data.sub_image_list
+            for i in range(batch):
+                angle, score = angles[i], scores[i]
+                if "180" == angle and score > self.cls_thresh:
+                    sub_images[i] = cv2.rotate(sub_images[i], cv2.ROTATE_180)
+            input_data.sub_image_list = sub_images
+        else:
+            input_data.infer_result = [(angle, score) for angle, score in zip(angles, scores)]
+
+        input_data.data = None
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/classification/cls_pre_node.py b/mindocr/infer/classification/cls_pre_node.py
new file mode 100644
index 000000000..f28e4eb6d
--- /dev/null
+++ b/mindocr/infer/classification/cls_pre_node.py
@@ -0,0 +1,41 @@
+import argparse
+import os
+import time
+import sys
+import numpy as np
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from infer.classification.classification import ClsPreprocess
+
+
+class ClsPreNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(ClsPreNode, self).__init__(args, msg_queue, tqdm_info)
+        self.cls_preprocesser = ClsPreprocess(args)
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        super().init_self_args()
+        return {"batch_size": 1}
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        if self.task_type.value == TaskType.REC.value:
+            image = input_data.frame[0]
+            data = [self.cls_preprocesser(image)["image"]]
+            input_data.sub_image_size = 1
+            input_data.data = data
+            self.send_to_next_module(input_data)
+        else:
+            sub_image_list = input_data.sub_image_list
+            data = [self.cls_preprocesser(split_image)["image"] for split_image in sub_image_list]
+            input_data.sub_image_size = len(sub_image_list)
+            input_data.data = data
+            self.send_to_next_module(input_data)
diff --git a/mindocr/infer/common/__init__.py b/mindocr/infer/common/__init__.py
new file mode 100644
index 000000000..9bd25705f
--- /dev/null
+++ b/mindocr/infer/common/__init__.py
@@ -0,0 +1,3 @@
+from .collect_node import CollectNode
+from .decode_node import DecodeNode
+from .handout_node import HandoutNode
diff --git a/mindocr/infer/common/collect_node.py b/mindocr/infer/common/collect_node.py
new file mode 100644
index 000000000..cd589f80d
--- /dev/null
+++ b/mindocr/infer/common/collect_node.py
@@ -0,0 +1,174 @@
+import os
+from collections import defaultdict
+from ctypes import c_uint64
+from multiprocessing import Manager
+
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.data_process.utils import cv_utils
+from pipeline.tasks import TaskType
+from pipeline.utils import log, safe_list_writer, visual_utils
+from pipeline.datatype import ProcessData, ProfilingData, StopData
+from pipeline.framework.module_base import ModuleBase
+
+RESULTS_SAVE_FILENAME = {
+    TaskType.DET: "det_results.txt",
+    TaskType.CLS: "cls_results.txt",
+    TaskType.REC:
"rec_results.txt", + TaskType.DET_REC: "pipeline_results.txt", + TaskType.DET_CLS_REC: "pipeline_results.txt", + TaskType.LAYOUT: "layout_results.txt", +} + + +class CollectNode(ModuleBase): + def __init__(self, args, msg_queue, tqdm_info): + super().__init__(args, msg_queue, tqdm_info) + self.image_sub_remaining = defaultdict(defaultdict) + self.image_pipeline_res = defaultdict(defaultdict) + self.infer_size = defaultdict(int) + self.image_total = Manager().Value(c_uint64, 0) + self.task_type = args.task_type + self.res_save_dir = args.res_save_dir + self.save_filename = RESULTS_SAVE_FILENAME[TaskType(self.task_type.value)] + + def init_self_args(self): + super().init_self_args() + + def _collect_stop(self, input_data): + self.image_total.value = input_data.image_total + + def _vis_results(self, image_name, image, taskid, data_type): + if self.args.crop_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)): + basename = os.path.basename(image_name) + filename = os.path.join(self.args.crop_save_dir, os.path.splitext(basename)[0]) + box_list = [np.array(x["points"]).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]] + crop_list = visual_utils.vis_crop(image, box_list) + for i, crop in enumerate(crop_list): + cv_utils.img_write(filename + "_crop_" + str(i) + ".jpg", crop) + + if self.args.vis_pipeline_save_dir: + basename = os.path.basename(image_name) + filename = os.path.join(self.args.vis_pipeline_save_dir, os.path.splitext(basename)[0]) + box_list = [np.array(x["points"]).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]] + text_list = [x["transcription"] for x in self.image_pipeline_res[taskid][image_name]] + box_text = visual_utils.vis_bbox_text(image, box_list, text_list, font_path=self.args.vis_font_path) + cv_utils.img_write(filename + ".jpg", box_text) + + if self.args.vis_det_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)): + basename = os.path.basename(image_name) + filename = os.path.join(self.args.vis_det_save_dir, os.path.splitext(basename)[0]) + box_list = [np.array(x).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]] + box_line = visual_utils.vis_bbox(image, box_list, [255, 255, 0], 2) + cv_utils.img_write(filename + ".jpg", box_line) + + # log.info(f"{image_name} is finished.") + + def final_text_save(self): + rst_dict = dict() + for rst in self.image_pipeline_res.values(): + rst_dict.update(rst) + save_filename = os.path.join(self.res_save_dir, self.save_filename) + safe_list_writer(rst_dict, save_filename) + # log.info(f"save infer result to {save_filename} successfully") + + def _collect_results(self, input_data: ProcessData): + taskid = input_data.taskid + if self.task_type.value in (TaskType.DET_REC.value, TaskType.DET_CLS_REC.value): + image_path = input_data.image_path[0] # bs=1 + for result in input_data.infer_result: + if result[-1] > 0.5: + if self.args.result_contain_score: + self.image_pipeline_res[taskid][image_path].append( + {"transcription": result[-2], "points": result[:-2], "score": str(result[-1])} + ) + else: + self.image_pipeline_res[taskid][image_path].append( + {"transcription": result[-2], "points": result[:-2]} + ) + if not input_data.infer_result: + self.image_pipeline_res[taskid][image_path] = [] + elif self.task_type.value == TaskType.DET.value: + image_path = input_data.image_path[0] # bs=1 + self.image_pipeline_res[taskid][image_path] = input_data.infer_result + elif self.task_type.value in (TaskType.REC.value, 
TaskType.CLS.value):
+            for image_path, infer_result in zip(input_data.image_path, input_data.infer_result):
+                self.image_pipeline_res[taskid][image_path] = infer_result
+        elif self.task_type.value == TaskType.LAYOUT.value:
+            for infer_result in input_data.infer_result:
+                image_path = infer_result.pop("image_id")
+                if image_path in self.image_pipeline_res[taskid]:
+                    self.image_pipeline_res[taskid][image_path].append(infer_result)
+                else:
+                    self.image_pipeline_res[taskid][image_path] = [infer_result]
+        else:
+            raise NotImplementedError("Task type is not supported.")
+
+        self._update_remaining(input_data)
+
+    def _update_remaining(self, input_data: ProcessData):
+        taskid = input_data.taskid
+        data_type = input_data.data_type
+        if self.task_type.value in (TaskType.DET_REC.value, TaskType.DET_CLS_REC.value):  # with sub image
+            for idx, image_path in enumerate(input_data.image_path):
+                if image_path in self.image_sub_remaining[taskid]:
+                    self.image_sub_remaining[taskid][image_path] -= input_data.sub_image_size
+                    if not self.image_sub_remaining[taskid][image_path]:
+                        self.image_sub_remaining[taskid].pop(image_path)
+                        self.infer_size[taskid] += 1
+                        if input_data.frame:
+                            self._vis_results(image_path, input_data.frame[idx], taskid, data_type)
+                else:
+                    remaining = input_data.sub_image_total - input_data.sub_image_size
+                    if remaining:
+                        self.image_sub_remaining[taskid][image_path] = remaining
+                    else:
+                        self.infer_size[taskid] += 1
+                        if input_data.frame:
+                            self._vis_results(image_path, input_data.frame[idx], taskid, data_type)
+        else:  # without sub image
+            for idx, image_path in enumerate(input_data.image_path):
+                self.infer_size[taskid] += 1
+                if input_data.frame:
+                    self._vis_results(image_path, input_data.frame[idx], taskid, data_type)
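+    # A worked example of the bookkeeping above (hypothetical numbers): suppose
+    # detection finds 7 boxes in one image (sub_image_total=7) and recognition
+    # delivers them in two batches with sub_image_size 4 and 3. The first batch
+    # stores remaining = 7 - 4 = 3 in image_sub_remaining; the second drains it
+    # to 0, pops the entry, and bumps infer_size[taskid], which process() below
+    # compares against task_images_num before emitting the collected result.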
+
+    def process(self, input_data):
+        if isinstance(input_data, ProcessData):
+            taskid = input_data.taskid
+            if input_data.taskid not in self.image_sub_remaining.keys():
+                self.image_sub_remaining[input_data.taskid] = defaultdict(int)
+            if input_data.taskid not in self.image_pipeline_res.keys():
+                self.image_pipeline_res[input_data.taskid] = defaultdict(list)
+            self._collect_results(input_data)
+            if self.infer_size[taskid] == input_data.task_images_num:
+                self.send_to_next_module({taskid: self.image_pipeline_res[taskid]})
+
+        elif isinstance(input_data, StopData):
+            self._collect_stop(input_data)
+            if input_data.exception:
+                self.stop_manager.value = True
+        else:
+            raise ValueError("unknown input data")
+
+        infer_size_sum = sum(self.infer_size.values())
+        if self.image_total.value and infer_size_sum == self.image_total.value:
+            self.final_text_save()
+            self.stop_manager.value = True
+
+    def stop(self):
+        profiling_data = ProfilingData(
+            module_name=self.module_name,
+            instance_id=self.instance_id,
+            process_cost_time=self.process_cost.value,
+            send_cost_time=self.send_cost.value,
+            image_total=self.image_total.value,
+        )
+        self.msg_queue.put(profiling_data, block=False)
+        self.is_stop = True
diff --git a/mindocr/infer/common/decode_node.py b/mindocr/infer/common/decode_node.py
new file mode 100644
index 000000000..eb78f8bbc
--- /dev/null
+++ b/mindocr/infer/common/decode_node.py
@@ -0,0 +1,51 @@
+import os
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.data_process.utils import cv_utils
+from pipeline.utils import log
+from pipeline.datatype import StopData
+from pipeline.framework.module_base import ModuleBase
+
+
+class DecodeNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super().__init__(args, msg_queue, tqdm_info)
+        self.cost_time = 0
+        self.avail_image_total = 0
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def process(self, input_data):
+        if isinstance(input_data, StopData):
+            input_data.image_total = self.avail_image_total
+            self.send_to_next_module(input_data)
+            return
+
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        # the input already contains np.ndarray frames; no need to read the images again
+        if len(input_data.frame) == len(input_data.image_path) and len(input_data.frame) > 0:
+            self.avail_image_total += len(input_data.frame)
+            self.send_to_next_module(input_data)
+        else:
+            img_read, img_path_read = [], []
+            for image_path in input_data.image_path:
+                try:
+                    img_read.append(cv_utils.img_read(image_path))
+                    img_path_read.append(image_path)
+                    self.avail_image_total += 1
+                except ValueError:
+                    log.info(f"{image_path} is unavailable and skipped")
+                    continue
+            input_data.frame = img_read
+            input_data.image_path = img_path_read
+            self.send_to_next_module(input_data)
diff --git a/mindocr/infer/common/handout_node.py b/mindocr/infer/common/handout_node.py
new file mode 100644
index 000000000..24f2fefab
--- /dev/null
+++ b/mindocr/infer/common/handout_node.py
@@ -0,0 +1,100 @@
+import os
+import sys
+
+import cv2
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.data_process.utils import cv_utils
+from pipeline.utils import log
+from pipeline.datatype import ProcessData, StopData, StopSign
+from pipeline.framework.module_base import ModuleBase
+
+
+class HandoutNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        tqdm_info["queue_len"] = msg_queue._maxsize
+        super().__init__(args, msg_queue, tqdm_info)
+        self.image_total = 0
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def process(self, input_mix_data):
+        if isinstance(input_mix_data, StopSign):
+            data = self.process_stop_sign()
+            self.send_to_next_module(data)
+        elif isinstance(input_mix_data, np.ndarray):
+            input_data, info_data = input_mix_data
+            data = self.process_image_array([input_data])
+            data.task_images_num = info_data[0]
+            data.taskid = info_data[1]
+            data.data_type = 1
+            self.send_to_next_module(data)
+        elif isinstance(input_mix_data, (tuple, list)):
+            input_data, info_data = input_mix_data
+            if len(input_data) == 0:
+                return
+            if cv_utils.check_type_in_container(input_data, str):
+                data = self.process_image_path(input_data)
+                data.data_type = 0
+            elif cv_utils.check_type_in_container(input_data, np.ndarray):
+                data = self.process_image_array(input_data)
+                data.data_type = 1
+            else:
+                raise ValueError(
+                    "unknown input data: input_data should be a StopSign, or a tuple/list containing str or np.ndarray"
+                )
+            data.task_images_num = info_data[0]
+            data.taskid = info_data[1]
+            self.send_to_next_module(data)
+        else:
+            raise ValueError(f"unknown input data: {type(input_mix_data)}")
+
+    def process_image_path(self, image_path_list):
+        """
+        image_path_list: List[str], paths to the input images
+        """
+        # log.info(f"sending {', '.join([os.path.basename(x) for x in image_path_list])} to pipeline")
+        data = ProcessData(image_path=image_path_list)
+        self.image_total += len(image_path_list)
+        return data
+
+    def process_image_array(self, image_array_list):
+        """
+        image_array_list: List[np.ndarray], arrays of images
+        """
+        frames = []
+        array_save_path = []
+        image_num = len(image_array_list)
+        for i in range(image_num):
+            if self.args.input_array_save_dir:
+                image_path = os.path.join(self.args.input_array_save_dir, f"input_array_{self.image_total}.jpg")
+                if len(image_array_list[i].shape) != 3:
+                    log.info(f"image_array_list[{i}] array with shape {image_array_list[i].shape} is invalid")
+                    continue
+                try:
+                    cv_utils.img_write(image_path, image_array_list[i])
+                except cv2.error:
+                    log.info(f"image_array_list[{i}] array with shape {image_array_list[i].shape} is invalid")
+                    continue
+                log.info(f"sending array (saved at {image_path}) to pipeline")
+                array_save_path.append(image_path)
+            else:
+                array_save_path.append(str(i))
+            frames.append(image_array_list[i])
+
+            self.image_total += 1
+        data = ProcessData(frame=frames, image_path=array_save_path)
+        return data
+
+    def process_stop_sign(self):
+        # image_total of StopData will be assigned in decode_node
+        return StopData(skip=True)
diff --git a/mindocr/infer/detection/__init__.py b/mindocr/infer/detection/__init__.py
new file mode 100644
index 000000000..b78397356
--- /dev/null
+++ b/mindocr/infer/detection/__init__.py
@@ -0,0 +1,3 @@
+from .det_infer_node import DetInferNode
+from .det_post_node import DetPostNode
+from .det_pre_node import DetPreNode
diff --git a/mindocr/infer/detection/det_infer_node.py b/mindocr/infer/detection/det_infer_node.py
new file mode 100644
index 000000000..28b7def15
--- /dev/null
+++ b/mindocr/infer/detection/det_infer_node.py
@@ -0,0 +1,40 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from infer.detection.detection import INFER_DET_MAP
+
+
+class DetInferNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(DetInferNode, self).__init__(args, msg_queue, tqdm_info)
+        self.args = args
+        self.det_model = None
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        with open(self.args.det_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        DetModel = INFER_DET_MAP[self.yaml_cfg.predict.backend]
+        self.det_model = DetModel(self.args)
+        super().init_self_args()
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data["image"]
+        pred = self.det_model([data])
+
+        input_data.data = {"pred": pred, "shape_list": input_data.data["shape_list"]}
+
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/detection/det_post_node.py b/mindocr/infer/detection/det_post_node.py
new file mode 100644
index 000000000..1e8bab3a2
--- /dev/null
+++ b/mindocr/infer/detection/det_post_node.py
@@ -0,0 +1,67 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+import cv2
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from mindocr.infer.detection.detection import DetPostProcess
+from tools.infer.text.utils import crop_text_region
+from pipeline.data_process.utils.cv_utils import crop_box_from_image
+
+
+class DetPostNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(DetPostNode, self).__init__(args, msg_queue, tqdm_info)
+        self.det_postprocess = DetPostProcess(args)
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        pred = input_data.data["pred"][0]
+
+        data_dict = {"shape_list": input_data.data["shape_list"]}
+        boxes = self.det_postprocess(pred, data_dict)
+
+        boxes = boxes["polys"][0]
+        infer_res_list = []
+        for box in boxes:
+            infer_res_list.append(box.tolist())
+
+        input_data.infer_result = infer_res_list
+
+        if self.task_type.value in (TaskType.DET_REC.value, TaskType.DET_CLS_REC.value):
+            input_data.sub_image_total = len(infer_res_list)
+            input_data.sub_image_size = len(infer_res_list)
+
+            image = input_data.frame[0]  # bs=1 for det
+            sub_image_list = []
+            for box in infer_res_list:
+                sub_image = crop_box_from_image(image, np.array(box))
+                sub_image_list.append(sub_image)
+            input_data.sub_image_list = sub_image_list
+
+        input_data.data = None
+
+        if not (self.args.crop_save_dir or self.args.vis_det_save_dir or self.args.vis_pipeline_save_dir):
+            input_data.frame = None
+
+        if not infer_res_list:
+            input_data.skip = True
+
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/detection/det_pre_node.py b/mindocr/infer/detection/det_pre_node.py
new file mode 100644
index 000000000..81485656b
--- /dev/null
+++ b/mindocr/infer/detection/det_pre_node.py
@@ -0,0 +1,44 @@
+import argparse
+import os
+import time
+import sys
+import numpy as np
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from infer.detection.detection import DetPreprocess
+
+
+class DetPreNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(DetPreNode, self).__init__(args, msg_queue, tqdm_info)
+        self.det_preprocesser = DetPreprocess(args)
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        super().init_self_args()
+        return {"batch_size": 1}
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+        if len(input_data.frame) == 0:
+            return
+
+        image = input_data.frame[0]  # bs = 1 for det
+        data = self.det_preprocesser({"image": image})
+
+        if len(data["image"].shape) == 3:
+            data["image"] = np.expand_dims(data["image"], 0)
+            data["shape_list"] = np.expand_dims(data["shape_list"], 0)
+
+        if self.task_type.value == TaskType.DET.value and not (self.args.crop_save_dir or self.args.vis_det_save_dir):
+            input_data.frame = None
+
+        input_data.data = data
+
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/detection/detection.py b/mindocr/infer/detection/detection.py
new file mode 100644
index 000000000..59af291e1
--- /dev/null
+++ b/mindocr/infer/detection/detection.py
@@ -0,0 +1,71 @@
+import logging
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+from typing import List
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from tools.infer.text.utils import get_ckpt_file
+from mindocr.data.transforms import create_transforms, run_transforms
+from mindocr.postprocess import build_postprocess
+from mindocr.infer.utils.model import MSModel, LiteModel
+
+
+algo_to_model_name = {
+    "DB": "dbnet_resnet50",
+    "DB++": "dbnetpp_resnet50",
+    "DB_MV3": "dbnet_mobilenetv3",
+    "DB_PPOCRv3": "dbnet_ppocrv3",
+    "PSE": "psenet_resnet152",
+}
+logger = logging.getLogger("mindocr")
+
+class DetPreprocess(object):
+    def __init__(self, args):
+        self.args = args
+        with open(args.det_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        for transform in self.yaml_cfg.predict.dataset.transform_pipeline:
+            if "DecodeImage" in transform:
+                transform["DecodeImage"].update({"keep_ori": True})
+                break
+        self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline)
+
+    def __call__(self, img):
+        data = {"image": img}
+        # skip the first transform (DecodeImage): the input image is already a decoded array
+        data = run_transforms(data, self.transforms[1:])
+        return data
+
+
+class DetModelMS(MSModel):
+    def __init__(self, args):
+        self.args = args
+        self.model_name = algo_to_model_name[args.det_algorithm]
+        self.config_path = args.det_config_path
+        self._init_model(self.model_name, self.config_path)
+
+
+class DetModelLite(LiteModel):
+    def __init__(self, args):
+        self.args = args
+        self.model_name = algo_to_model_name[args.det_algorithm]
+        self.config_path = args.det_config_path
+        self._init_model(self.model_name, self.config_path)
+
+INFER_DET_MAP = {"MindSporeLite": DetModelLite, "MindSpore": DetModelMS}
+
+
+class DetPostProcess(object):
+    def __init__(self, args):
+        self.args = args
+        with open(args.det_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.postprocessor = build_postprocess(self.yaml_cfg.postprocess)
+
+    def __call__(self, pred, data):
+        return self.postprocessor(pred, **data)
\ No newline at end of file
diff --git a/mindocr/infer/node_config.py b/mindocr/infer/node_config.py
new file mode 100644
index 000000000..c417e66a3
--- /dev/null
+++ b/mindocr/infer/node_config.py
@@ -0,0 +1,89 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from mindocr.infer.detection import DetPreNode, DetInferNode, DetPostNode
+from mindocr.infer.recognition import RecPreNode, RecInferNode, RecPostNode
+from mindocr.infer.classification import ClsPreNode, ClsInferNode, ClsPostNode
+from mindocr.infer.common import HandoutNode, DecodeNode, CollectNode
+from pipeline.tasks import TaskType
+from pipeline.utils import log
+
+__all__ = [
+    "MODEL_DICT",
+    "DET_DESC",
+    "CLS_DESC",
+    "REC_DESC",
+    "DET_REC_DESC",
+    "DET_CLS_REC_DESC",
+]
+
+DET_DESC = [
+    (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)),
+    (("DecodeNode", "0", 1), ("DetPreNode", "0", 1)),
+    (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)),
+    (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)),
+    (("DetPostNode", "0", 1), ("CollectNode", "0", 1)),
+]
+
+REC_DESC = [
+    (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)),
+    (("DecodeNode", "0", 1), ("RecPreNode", "0", 1)),
+    (("RecPreNode", "0", 1), ("RecInferNode", "0", 1)),
+    (("RecInferNode", "0", 1), ("RecPostNode", "0", 1)),
+    (("RecPostNode", "0", 1), ("CollectNode", "0", 1)),
+]
+
+CLS_DESC = [
+    (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)),
+    (("DecodeNode", "0", 1), ("ClsPreNode", "0", 1)),
+    (("ClsPreNode", "0", 1), ("ClsInferNode", "0", 1)),
+    (("ClsInferNode", "0", 1), ("ClsPostNode", "0", 1)),
+    (("ClsPostNode", "0", 1), ("CollectNode", "0", 1)),
+]
+
+DET_REC_DESC = [
+    (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)),
+    (("DecodeNode", "0", 1), ("DetPreNode", "0", 1)),
+    (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)),
+    (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)),
+    (("DetPostNode", "0", 1), ("RecPreNode", "0", 1)),
+    (("RecPreNode", "0", 1), ("RecInferNode", "0", 1)),
+    (("RecInferNode", "0", 1), ("RecPostNode", "0", 1)),
+    (("RecPostNode", "0", 1), ("CollectNode", "0", 1)),
+]
+
+DET_CLS_REC_DESC = [
+    (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)),
+    (("DecodeNode", "0", 1), ("DetPreNode", "0", 1)),
+    (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)),
+    (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)),
+    (("DetPostNode", "0", 1), ("ClsPreNode", "0", 1)),
+    (("ClsPreNode", "0", 1), ("ClsInferNode", "0", 1)),
+    (("ClsInferNode", "0", 1), ("ClsPostNode", "0", 1)),
+    (("ClsPostNode", "0", 1), ("RecPreNode", "0", 1)),
+    (("RecPreNode", "0", 1), ("RecInferNode", "0", 1)),
+    (("RecInferNode", "0", 1), ("RecPostNode", "0", 1)),
+    (("RecPostNode", "0", 1), ("CollectNode", "0", 1)),
+]
+
+MODEL_DICT = {
+    TaskType.DET: DET_DESC,
+    TaskType.REC: REC_DESC,
+    TaskType.CLS: CLS_DESC,
+    TaskType.DET_REC: DET_REC_DESC,
+    TaskType.DET_CLS_REC: DET_CLS_REC_DESC,
+    # TaskType.LAYOUT: LAYOUT_DESC  # TODO
+}
+
+def processor_initiator(classname):
+    try:
+        processor = getattr(sys.modules.get(__name__), classname)
+    except AttributeError as error:
+        log.error("%s doesn't exist.", classname)
+        raise error
+    if isinstance(processor, type):
+        return processor
+    raise TypeError(f"{classname} is not a class.")
diff --git a/mindocr/infer/recognition/__init__.py b/mindocr/infer/recognition/__init__.py
new file mode 100644
index 000000000..bc416dda5
--- /dev/null
+++ b/mindocr/infer/recognition/__init__.py
@@ -0,0 +1,3 @@
+from .rec_infer_node import RecInferNode
+from .rec_post_node import RecPostNode
+from .rec_pre_node import RecPreNode
diff --git
a/mindocr/infer/recognition/rec_infer_node.py b/mindocr/infer/recognition/rec_infer_node.py
new file mode 100644
index 000000000..8bb63f641
--- /dev/null
+++ b/mindocr/infer/recognition/rec_infer_node.py
@@ -0,0 +1,51 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from infer.recognition.recognition import INFER_REC_MAP
+
+
+class RecInferNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(RecInferNode, self).__init__(args, msg_queue, tqdm_info)
+        self.args = args
+        self.rec_model = None
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        with open(self.args.rec_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.batch_size = self.yaml_cfg.predict.loader.batch_size
+        RecModel = INFER_REC_MAP[self.yaml_cfg.predict.backend]
+        self.rec_model = RecModel(self.args)
+        super().init_self_args()
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data
+        data = [np.expand_dims(d, 0) for d in data if len(d.shape) == 3]
+        data = np.concatenate(data, axis=0)
+
+        preds = []
+        for batch_i in range(data.shape[0] // self.batch_size + 1):
+            start_i = batch_i * self.batch_size
+            end_i = (batch_i + 1) * self.batch_size
+            d = data[start_i:end_i]
+            if d.shape[0] == 0:
+                continue
+            pred = self.rec_model([d])
+            preds.append(pred[0])
+        preds = np.concatenate(preds, axis=0)
+        input_data.data = {"pred": preds}
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/recognition/rec_post_node.py b/mindocr/infer/recognition/rec_post_node.py
new file mode 100644
index 000000000..d93f1935f
--- /dev/null
+++ b/mindocr/infer/recognition/rec_post_node.py
@@ -0,0 +1,48 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+import cv2
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from infer.recognition.recognition import RecPostProcess
+from tools.infer.text.utils import crop_text_region
+from pipeline.data_process.utils.cv_utils import crop_box_from_image
+
+
+class RecPostNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(RecPostNode, self).__init__(args, msg_queue, tqdm_info)
+        self.rec_postprocess = RecPostProcess(args)
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data
+        pred = data["pred"]
+        output = self.rec_postprocess(pred)
+        texts = output["texts"]
+        confs = output["confs"]
+        if self.task_type.value == TaskType.REC.value:
+            input_data.infer_result = output["texts"]
+        else:
+            for i, (text, conf) in enumerate(zip(texts, confs)):
+                input_data.infer_result[i].append(text)
+                input_data.infer_result[i].append(conf)
+        input_data.data = None
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/recognition/rec_pre_node.py b/mindocr/infer/recognition/rec_pre_node.py
new file mode 100644
index 000000000..69658195f
--- /dev/null
+++ b/mindocr/infer/recognition/rec_pre_node.py
@@ -0,0 +1,41 @@
+import argparse
+import os
+import time
+import sys
+import numpy as np
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from infer.recognition.recognition import RecPreprocess
+
+
+class RecPreNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(RecPreNode, self).__init__(args, msg_queue, tqdm_info)
+        self.rec_preprocesser = RecPreprocess(args)
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        super().init_self_args()
+        return {"batch_size": 1}
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        if self.task_type.value == TaskType.REC.value:
+            image = input_data.frame[0]
+            data = [self.rec_preprocesser(image)["image"]]
+            input_data.sub_image_size = 1
+            input_data.data = data
+            self.send_to_next_module(input_data)
+        else:
+            sub_image_list = input_data.sub_image_list
+            data = [self.rec_preprocesser(split_image)["image"] for split_image in sub_image_list]
+            input_data.sub_image_size = len(sub_image_list)
+            input_data.data = data
+            self.send_to_next_module(input_data)
diff --git a/mindocr/infer/recognition/recognition.py b/mindocr/infer/recognition/recognition.py
new file mode 100644
index 000000000..078d8ef1a
--- /dev/null
+++ b/mindocr/infer/recognition/recognition.py
@@ -0,0 +1,68 @@
+import logging
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+from typing import List
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from tools.infer.text.utils import get_ckpt_file
+from mindocr.data.transforms import create_transforms, run_transforms
+from mindocr.postprocess import build_postprocess
+from mindocr.infer.utils.model import MSModel, LiteModel
+
+
+algo_to_model_name = {
+    "CRNN": "crnn_resnet34",
+    "RARE": "rare_resnet34",
+    "CRNN_CH": "crnn_resnet34_ch",
+    "RARE_CH": "rare_resnet34_ch",
+    "SVTR": "svtr_tiny",
+    "SVTR_PPOCRv3_CH": "svtr_ppocrv3_ch",
+}
+logger = logging.getLogger("mindocr")
+
+class RecPreprocess(object):
+    def __init__(self, args):
+        self.args = args
+        with open(args.rec_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline)
+
+    def __call__(self, img):
+        data = {"image": img}
+        # skip the first transform (DecodeImage): the input image is already a decoded array
+        data = run_transforms(data, self.transforms[1:])
+        return data
+
+
+class RecModelMS(MSModel):
+    def __init__(self, args):
+        self.args = args
+        self.model_name = algo_to_model_name[args.rec_algorithm]
+        self.config_path = args.rec_config_path
+        self._init_model(self.model_name, self.config_path)
+
+
+class RecModelLite(LiteModel):
+    def __init__(self, args):
+        self.args = args
+        self.model_name = algo_to_model_name[args.rec_algorithm]
+        self.config_path = args.rec_config_path
+        self._init_model(self.model_name, self.config_path)
+
+INFER_REC_MAP = {"MindSporeLite": RecModelLite, "MindSpore": RecModelMS}
+
+class RecPostProcess(object):
+    def __init__(self, args):
+        self.args = args
+        with open(args.rec_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.postprocessor = build_postprocess(self.yaml_cfg.postprocess)
+
+    def __call__(self, pred):
+        return self.postprocessor(pred)
\ No newline at end of file
diff --git a/mindocr/infer/utils/model.py b/mindocr/infer/utils/model.py
new file mode 100644
index 000000000..a6946a2d4
--- /dev/null
+++ b/mindocr/infer/utils/model.py
@@ -0,0 +1,118 @@
+import logging
+import os
+from collections import defaultdict
+from ctypes import c_uint64
+from multiprocessing import Manager
+
+from abc import ABCMeta, abstractmethod
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from tools.infer.text.utils import get_ckpt_file
+from mindocr.models.builder import build_model
+from typing import List
+
+logger = logging.getLogger("mindocr")
+
+class BaseModel(metaclass=ABCMeta):
+    def __init__(self, args) -> None:
+        self.model = None
+        self.args = args
+        self.pretrained = True
+        self.ckpt_load_path = ""
+        self.amp_level = "O0"
+
+    @abstractmethod
+    def __call__(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
+        pass
+
+    @abstractmethod
+    def _init_model(self, model_name, config_path):
+        pass
+
+
+class MSModel(BaseModel):
+    def __init__(self, args) -> None:
+        super().__init__(args)
+
+    def _init_model(self, model_name, config_path):
+        global ms
+        import mindspore as ms
+
+        self.config_path = config_path
+        with open(self.config_path, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.ckpt_load_path = self.yaml_cfg.predict.ckpt_load_path
+        if self.ckpt_load_path is None:
+            self.pretrained = True
+            self.ckpt_load_path = None
+        else:
+            self.pretrained = False
+            self.ckpt_load_path = get_ckpt_file(self.ckpt_load_path)
+
+        ms.set_context(device_target=self.yaml_cfg.predict.get("device_target", "Ascend"))
+        ms.set_context(device_id=self.yaml_cfg.predict.get("device_id", 0))
+        ms.set_context(mode=self.yaml_cfg.predict.get("mode", 0))
+        if self.yaml_cfg.predict.get("max_device_memory", None):
+            ms.set_context(max_device_memory=self.yaml_cfg.predict.get("max_device_memory"))
+        self.amp_level = self.yaml_cfg.predict.get("amp_level", "O0")
+        if ms.get_context("device_target") == "GPU" and self.amp_level == "O3":
+            logger.warning(
+                "Model prediction does not support amp_level O3 on GPU currently. "
+                "The program has switched to amp_level O2 automatically."
+            )
+            self.amp_level = "O2"
+        self.model = build_model(
+            model_name,
+            ckpt_load_path=self.ckpt_load_path,
+            amp_level=self.amp_level,
+        )
+        self.model.set_train(False)
+        logger.info(
+            "Init mindspore model: {}. Model weights loaded from {}".format(
+                model_name, "pretrained url" if self.pretrained else self.ckpt_load_path
+            )
+        )
+
+    def __call__(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
+        input_ms = [ms.Tensor.from_numpy(input) for input in inputs]
+        output = self.model(*input_ms)
+        outputs = [output.asnumpy()]
+        return outputs
+
+
+class LiteModel(BaseModel):
+    def __init__(self, args) -> None:
+        super().__init__(args)
+
+    def _init_model(self, model_name, config_path):
+        global mslite
+        import mindspore_lite as mslite
+
+        self.config_path = config_path
+        with open(self.config_path, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.ckpt_load_path = self.yaml_cfg.predict.ckpt_load_path
+        context = mslite.Context()
+        device_target = self.yaml_cfg.predict.get("device_target", "Ascend")
+        context.target = [device_target.lower()]
+        if device_target.lower() == "ascend":
+            context.ascend.device_id = self.yaml_cfg.predict.get("device_id", 0)
+        elif device_target.lower() == "gpu":
+            context.gpu.device_id = self.yaml_cfg.predict.get("device_id", 0)
+        else:
+            pass
+        self.model = mslite.Model()
+        self.model.build_from_file(self.ckpt_load_path, mslite.ModelType.MINDIR, context)
+
+    def __call__(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
+        model_inputs = self.model.get_inputs()
+        inputs_shape = [list(input.shape) for input in inputs]
+        self.model.resize(model_inputs, inputs_shape)
+        for i, input in enumerate(inputs):
+            model_inputs[i].set_data_from_numpy(input)
+        model_outputs = self.model.predict(model_inputs)
+        outputs = [output.get_data_to_numpy().copy() for output in model_outputs]
+        return outputs
diff --git a/pipeline/__init__.py b/pipeline/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pipeline/data_process/utils/cv_utils.py b/pipeline/data_process/utils/cv_utils.py
new file mode 100644
index 000000000..54dec7729
--- /dev/null
+++ b/pipeline/data_process/utils/cv_utils.py
@@ -0,0 +1,72 @@
+import os
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+
+
+def get_hw_of_img(image: np.ndarray):
+    """
+    get the h and w of an hwc image
+    """
+    if len(image.shape) == 3:
+        # bgr/rgb
+        height, width, _ = image.shape
+    elif len(image.shape) == 2:
+        # gray
+        height, width = image.shape
+    else:
+        raise TypeError("image is not a color or gray image")
+
+    return height, width
+
+
+def get_batch_hw_of_img(images: List[np.ndarray]) -> Tuple:
+    return tuple(get_hw_of_img(img) for img in images)
+
+
+def crop_box_from_image(image, box):
+    if box.shape != (4, 2):
+        raise ValueError("shape of crop box must be 4*2")
+    box = box.astype(np.float32)
+    img_crop_width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
+    img_crop_height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
+    pts_std = np.float32([[0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height]])
+    m = cv2.getPerspectiveTransform(box, pts_std)
+    dst_img = cv2.warpPerspective(
+        image, m, (img_crop_width, img_crop_height), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC
+    )
+    dst_img_height, dst_img_width = dst_img.shape[0:2]
+    if dst_img_width != 0 and dst_img_height / dst_img_width >= 1.5:
+        dst_img = np.rot90(dst_img)
+
+    return dst_img
+
+
+def img_read(path: str):
+    """
+    Read a BGR image.
+    """
+    img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_COLOR)
+
+    if img is None:
+        raise ValueError(f"Error! Cannot load the image at {path}")
+
+    return img
+
+
+def img_write(path: str, img: np.ndarray):
+    filename = os.path.abspath(path)
+    cv2.imencode(os.path.splitext(filename)[1], img)[1].tofile(filename)
+
+
+def check_type_in_container(input_data, t, skip_last=False):
+    if skip_last:
+        check_data = input_data[:-1]
+    else:
+        check_data = input_data
+    for data in check_data:
+        if not isinstance(data, t):
+            return False
+    return True
diff --git a/pipeline/datatype/__init__.py b/pipeline/datatype/__init__.py
new file mode 100644
index 000000000..f8469deb2
--- /dev/null
+++ b/pipeline/datatype/__init__.py
@@ -0,0 +1,3 @@
+from .message_data import ProfilingData, StopSign
+from .module_data import ModuleConnectDesc, ModuleDesc, ModuleInitArgs
+from .process_data import ProcessData, StopData
diff --git a/pipeline/datatype/message_data.py b/pipeline/datatype/message_data.py
new file mode 100644
index 000000000..d3a665269
--- /dev/null
+++ b/pipeline/datatype/message_data.py
@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class StopSign:
+    stop: bool = True
+
+
+@dataclass
+class ProfilingData:
+    module_name: str = ""
+    instance_id: int = 0
+    device_id: int = 0
+    process_cost_time: float = 0.0
+    send_cost_time: float = 0.0
+    image_total: int = -1
diff --git a/pipeline/datatype/module_data.py b/pipeline/datatype/module_data.py
new file mode 100644
index 000000000..d5c82207b
--- /dev/null
+++ b/pipeline/datatype/module_data.py
@@ -0,0 +1,45 @@
+from dataclasses import dataclass, field
+from enum import Enum
+
+
+class ConnectType(Enum):
+    MODULE_CONNECT_ONE = 0
+    MODULE_CONNECT_CHANNEL = 1
+    MODULE_CONNECT_PAIR = 2
+    MODULE_CONNECT_RANDOM = 3
+
+
+@dataclass
+class ModuleOutputInfo:
+    module_name: str
+    connect_type: ConnectType
+    output_queue_list_size: int
+    output_queue_list: list = field(default_factory=lambda: [])
+
+
+@dataclass
+class ModuleInitArgs:
+    pipeline_name: str
+    module_type: str
+    module_name: str
+    instance_id: int = -1
+
+
+@dataclass
+class ModuleDesc:
+    module_type: str  # node type, e.g. HandoutNode
+    module_name: str  # node name, e.g. 1; a node is uniquely identified by {module_type}{module_name}
+    module_count: int
+
+
+@dataclass
+class ModuleConnectDesc:
+    module_send_name: str
+    module_recv_name: str
+    connect_type: ConnectType = field(default_factory=lambda: ConnectType.MODULE_CONNECT_RANDOM)
+
+
+@dataclass
+class ModulesInfo:
+    module_list: list = field(default_factory=lambda: [])
+    input_queue_list: list = field(default_factory=lambda: [])
diff --git a/pipeline/datatype/process_data.py b/pipeline/datatype/process_data.py
new file mode 100644
index 000000000..b615572cf
--- /dev/null
+++ b/pipeline/datatype/process_data.py
@@ -0,0 +1,43 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Union
+
+import numpy as np
+
+
+@dataclass
+class ProcessData:
+    # skip each compute node
+    skip: bool = False
+    # prediction results of each image
+    infer_result: list = field(default_factory=lambda: [])
+
+    # image basic info
+    image_path: List[str] = field(default_factory=lambda: [])
+    frame: List[np.ndarray] = field(default_factory=lambda: [])
+
+    # sub image of detection box, for det (+ cls) + rec
+    sub_image_total: int = 0  # len(sub_image_list_0) + len(sub_image_list_1) + ...
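+    # e.g. one page with 7 detected boxes gives sub_image_total = 7, while each
+    # downstream batch carries its own sub_image_size (hypothetical sizes 4 and 3);
+    # CollectNode._update_remaining drains the difference to zero per image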
+    sub_image_list: list = field(default_factory=lambda: [])
+    sub_image_size: int = 0  # len of sub_image_list
+
+    # data for preprocess -> infer -> postprocess
+    data: Union[np.ndarray, List[np.ndarray], Dict] = None
+
+    # confidences of the results from rec
+    score: List[float] = field(default_factory=lambda: [])
+
+    # the images fed into the ocr system in the same call share the same taskid
+    taskid: int = 0
+
+    # number of images sharing the same taskid
+    task_images_num: int = 0
+
+    # type of the raw input: 0: string path, 1: np.ndarray
+    data_type: int = 0
+
+
+@dataclass
+class StopData:
+    skip: bool = True
+    image_total: int = 0
+    exception: bool = False
diff --git a/pipeline/framework/module_base.py b/pipeline/framework/module_base.py
new file mode 100644
index 000000000..4a6d71b9a
--- /dev/null
+++ b/pipeline/framework/module_base.py
@@ -0,0 +1,129 @@
+import os
+import tqdm
+import sys
+import time
+from abc import abstractmethod
+from ctypes import c_longdouble
+from multiprocessing import Manager
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.datatype import ModuleInitArgs, ProfilingData
+from pipeline.datatype import StopData, StopSign
+from pipeline.utils import log, safe_div
+
+
+class ModuleBase(object):
+    def __init__(self, args, msg_queue, tqdm_info):
+        self.args = args
+        self.pipeline_name = ""
+        self.module_name = ""
+        self.without_input_queue = False
+        self.instance_id = 0
+        self.is_stop = False
+        self.msg_queue = msg_queue
+        self.input_queue = None
+        self.output_queue = None
+        self.send_cost = Manager().Value(typecode=c_longdouble, value=0)
+        self.process_cost = Manager().Value(typecode=c_longdouble, value=0)
+        self.display_id = tqdm_info["i"]
+        self.bar = tqdm.tqdm(total=tqdm_info["queue_len"],
+                             desc=f"{self.display_id}. {self.module_name}",
+                             position=self.display_id,
+                             leave=False,
+                             bar_format="{l_bar}{bar}|{n_fmt}/{total_fmt}",
+                             ncols=100)
+
+    def assign_init_args(self, init_args: ModuleInitArgs):
+        self.pipeline_name = init_args.pipeline_name
+        self.module_name = init_args.module_name
+        self.instance_id = init_args.instance_id
+
+    def process_handler(self, stop_manager, module_params, input_queue, output_queue):
+        self.input_queue = input_queue
+        self.output_queue = output_queue
+        self.stop_manager = stop_manager
+        self.queue_num = 0
+
+        try:
+            params = self.init_self_args()
+            if params:
+                module_params.update(**params)
+        except Exception as error:
+            log.error(f"{self.__class__.__name__} init failed: {error}")
+            raise error
+
+        # waiting for the init sign
+        while not self.msg_queue.full():
+            continue
+
+        # waiting for the release of the stop sign
+        while self.stop_manager.value:
+            continue
+
+        process_num = 0
+
+        while True:
+            time.sleep(self.args.node_fetch_interval)
+            if self.stop_manager.value:
+                break
+            if self.input_queue.empty():
+                continue
+
+            process_num += 1
+            data = self.input_queue.get(block=True)
+            qsize = self.input_queue.qsize()
+            delta = qsize - self.queue_num
+            self.bar.update(delta)
+            self.queue_num = qsize
+            self.bar.set_description(f"{self.display_id}. Node:{self.module_name}, Has Processed:{process_num}, input queue:{qsize}")
+            self.call_process(data)
+        self.bar.close()
+
+    def call_process(self, send_data=None):
+        if send_data is not None or self.without_input_queue:
+            start_time = time.time()
+            try:
+                self.process(send_data)
+            except Exception as error:
+                self.process(StopData(exception=True))
+                image_path = [os.path.basename(filename) for filename in getattr(send_data, "image_path", [])]
+                log.exception(f"ERROR occurred in {self.module_name} module for {', '.join(image_path)}: {error}.")
+
+            cost_time = time.time() - start_time
+            self.process_cost.value += cost_time
+
+    @abstractmethod
+    def process(self, input_data):
+        pass
+
+    @abstractmethod
+    def init_self_args(self):
+        self.msg_queue.put(f"{self.__class__.__name__} instance id {self.instance_id} init complete")
+        log.info(f"{self.__class__.__name__} instance id {self.instance_id} init complete")
+
+    def send_to_next_module(self, output_data):
+        if self.is_stop:
+            return
+        start_time = time.time()
+        self.output_queue.put(output_data, block=True)
+        cost_time = time.time() - start_time
+        self.send_cost.value += cost_time
+
+    def get_module_name(self):
+        return self.module_name
+
+    def get_instance_id(self):
+        return self.instance_id
+
+    def stop(self):
+        profiling_data = ProfilingData(
+            module_name=self.module_name,
+            instance_id=self.instance_id,
+            process_cost_time=self.process_cost.value,
+            send_cost_time=self.send_cost.value,
+        )
+        self.msg_queue.put(profiling_data, block=False)
+        self.is_stop = True
diff --git a/pipeline/framework/module_manager.py b/pipeline/framework/module_manager.py
new file mode 100644
index 000000000..1aa563002
--- /dev/null
+++ b/pipeline/framework/module_manager.py
@@ -0,0 +1,160 @@
+import os
+import sys
+
+from collections import defaultdict, namedtuple
+from ctypes import c_bool
+from multiprocessing import Manager, Process, Queue
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.datatype.module_data import ModuleInitArgs, ModulesInfo
+from pipeline.utils import log
+from mindocr.infer.node_config import processor_initiator
+
+OutputRegisterInfo = namedtuple("OutputRegisterInfo", ["pipeline_name", "module_send", "module_recv"])
+
+
+class ModuleManager:
+    MODULE_QUEUE_MAX_SIZE = 16
+
+    def __init__(self, msg_queue: Queue, task_queue: Queue, result_queue: Queue, args):
+        self.pipeline_map = defaultdict(lambda: defaultdict(ModulesInfo))
+        self.msg_queue = msg_queue
+        self.stop_manager = Manager().Value(c_bool, True)
+        self.args = args
+        self.pipeline_name = ""
+        self.process_list = []
+        self.queue_list = []
+        self.pipeline_queue_map = defaultdict(lambda: defaultdict(list))
+        self.task_queue = task_queue  # input_queue for HandoutNode
+        self.result_queue = result_queue  # output_queue for CollectNode
+        self.module_params = Manager().dict()
+
+    @staticmethod
+    def stop_module(module):
+        module.stop()
+
+    @staticmethod
+    def init_module_instance(module_instance, instance_id, pipeline_name, module_type, module_name):
+        init_args = ModuleInitArgs(pipeline_name=pipeline_name,
+                                   module_name=module_name,
+                                   module_type=module_type,
+                                   instance_id=instance_id)
+        module_instance.assign_init_args(init_args)
+
+    def register_modules(self, pipeline_name: str, module_desc_list: list, default_count: int):
+        log.info("----------------------------------------------------")
+        log.info("---------------register_modules start---------------")
+        modules_info_dict = self.pipeline_map[pipeline_name]
self.pipeline_map[pipeline_name] + self.pipeline_name = pipeline_name + + for i, module_desc in enumerate(module_desc_list): + log.info("+++++++++++++++++++++++++++++++++++++") + log.info(module_desc) + log.info("+++++++++++++++++++++++++++++++++++++") + module_count = default_count if module_desc.module_count == -1 else module_desc.module_count + module_info = ModulesInfo() + for instance_id in range(module_count): + tqdm_info = {"i": i, "queue_len": self.MODULE_QUEUE_MAX_SIZE} + module_instance = processor_initiator(module_desc.module_type)(self.args, self.msg_queue, tqdm_info) + self.init_module_instance(module_instance, + instance_id, + pipeline_name, + module_desc.module_type, + module_desc.module_name) + + module_info.module_list.append(module_instance) + modules_info_dict[module_desc.module_name] = module_info + + self.pipeline_map[pipeline_name] = modules_info_dict + + log.info("----------------register_modules end---------------") + log.info("----------------------------------------------------") + + def register_module_connects(self, pipeline_name: str, connect_desc_list: list): + if pipeline_name not in self.pipeline_map: + return + + log.info("----------------------------------------------------") + log.info("-----------register_module_connects start-----------") + + modules_info_dict = self.pipeline_map[pipeline_name] + connect_info_dict = self.pipeline_queue_map[pipeline_name] + last_module = None + for connect_desc in connect_desc_list: + send_name = connect_desc.module_send_name + recv_name = connect_desc.module_recv_name + log.info("+++++++++++++++++++++++++++++++++++++") + log.info(f"Add Connection Between {send_name} And {recv_name}") + log.info("+++++++++++++++++++++++++++++++++++++") + + if send_name not in modules_info_dict: + raise ValueError(f"cannot find send module {send_name}") + + if recv_name not in modules_info_dict: + raise ValueError(f"cannot find receive module {recv_name}") + + queue = Queue(self.MODULE_QUEUE_MAX_SIZE) + connect_info_dict[send_name].append(queue) + connect_info_dict[recv_name].append(queue) + last_module = recv_name + connect_info_dict[last_module].append(self.result_queue) + + log.info("------------register_module_connects end------------") + log.info("----------------------------------------------------") + + def run_pipeline(self): + log.info("-------------- start pipeline-----------------------") + log.info("----------------------------------------------------") + + for pipeline_name in self.pipeline_map.keys(): + modules_info_dict = self.pipeline_map[pipeline_name] + connect_info_dict = self.pipeline_queue_map[pipeline_name] + for module_name in modules_info_dict.keys(): + queue_list = connect_info_dict[module_name] + if len(queue_list) == 1: + input_queue = self.task_queue + output_queue = queue_list[0] + else: + input_queue = queue_list[0] + output_queue = queue_list[1] + + for module in modules_info_dict[module_name].module_list: + self.process_list.append( + Process( + target=module.process_handler, + args=(self.stop_manager, self.module_params, input_queue, output_queue), + daemon=True, + ) + ) + + for process in self.process_list: + process.start() + + def deinit_pipeline_module(self): + # the empty() is not reliable, double check the msg queue is empty for receive the profiling data + while not self.msg_queue.empty(): + self.msg_queue.get() + + for queue in self.queue_list: + while not queue.empty(): + queue.get(block=False) + queue.close() + queue.join_thread() + + # send the profiling data + for pipeline_name in 
self.pipeline_map.keys(): + modules_info_dict = self.pipeline_map[pipeline_name] + for module_name in modules_info_dict.keys(): + for module in modules_info_dict[module_name].module_list: + self.stop_module(module=module) + + # release all resource + for process in self.process_list: + if process.is_alive(): + process.kill() + + log.info("------------------pipeline stopped------------------") + log.info("----------------------------------------------------") diff --git a/pipeline/framework/pipeline_manager.py b/pipeline/framework/pipeline_manager.py new file mode 100644 index 000000000..348a569e3 --- /dev/null +++ b/pipeline/framework/pipeline_manager.py @@ -0,0 +1,153 @@ +import argparse +import os +import time +import sys +from collections import defaultdict +from multiprocessing import Manager, Process, Queue + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from pipeline.utils import log, safe_div +from pipeline.datatype import ModuleConnectDesc, ModuleDesc +from pipeline.datatype import StopData, StopSign +from pipeline.framework.module_manager import ModuleManager +from pipeline.tasks import SUPPORTED_TASK_BASIC_MODULE, TaskType +# ZHQ TODO +from mindocr.infer.node_config import MODEL_DICT as MODEL_DICT + + +class ParallelPipelineManager: + TASK_QUEUE_SIZE = 32 + + def __init__(self, args: argparse.Namespace): + self.args = args + self.input_queue = Queue(self.TASK_QUEUE_SIZE) + self.result_queue = Queue(self.TASK_QUEUE_SIZE) + self.process = Process(target=self._build_pipeline_kernel) + self.module_params = Manager().dict() + + def start_pipeline(self): + self.process.start() + self.input_queue.get(block=True) + + def stop_pipeline(self): + self.input_queue.put(StopSign(), block=True) + self.process.join() + self.process.close() + + def fetch_result(self): + if not self.result_queue.empty(): + rst_data = self.result_queue.get(block=True) + else: + rst_data = None + return rst_data + + def pipeline_graph(self, task_type): + module_order = SUPPORTED_TASK_BASIC_MODULE[TaskType(task_type.value)] + module_desc_names_set = set() + module_desc_list = [] + module_connect_desc_list = [] + + for model_name in module_order: + model_name = model_name + for edge in MODEL_DICT.get(model_name, []): + # Add Node + src_node_info, tgt_node_info = edge + src_node_name = src_node_info[0] + src_node_info[1] + if src_node_name not in module_desc_names_set: + module_desc_list.append(ModuleDesc(src_node_info[0], src_node_name, src_node_info[2])) + module_desc_names_set.add(src_node_name) + tgt_node_name = tgt_node_info[0] + tgt_node_info[1] + if tgt_node_name not in module_desc_names_set: + module_desc_list.append(ModuleDesc(tgt_node_info[0], tgt_node_name, tgt_node_info[2])) + module_desc_names_set.add(tgt_node_name) + module_connect_desc_list.append( + ModuleConnectDesc(src_node_name, tgt_node_name) + ) + module_size = sum(desc.module_count for desc in module_desc_list) + log.info(f"module_size: {module_size}") + return module_order, module_size, module_desc_list, module_connect_desc_list + + + def _build_pipeline_kernel(self): + """ + build and register pipeline + """ + task_type = self.args.task_type + + module_order, module_size, module_desc_list, module_connect_desc_list = self.pipeline_graph(task_type) + + msg_queue = Queue(module_size) + + manager = ModuleManager(msg_queue, self.input_queue, self.result_queue, self.args) + manager.register_modules(str(os.getpid()), module_desc_list, 1) + 
manager.register_module_connects(str(os.getpid()), module_connect_desc_list)
+
+        # start the pipeline and let every module initialize
+        manager.run_pipeline()
+
+        # wait until every module has posted its init message and its params
+        while not msg_queue.full() or len(manager.module_params) != len(module_order):
+            time.sleep(0.1)
+
+        for _ in range(module_size):
+            msg_queue.get()
+
+        self.module_params.update(**manager.module_params)
+
+        # release the blocking get() in start_pipeline()
+        self.input_queue.put(StopSign(), block=True)
+
+        manager.stop_manager.value = False
+
+        start_time = time.time()
+
+        while not manager.stop_manager.value:
+            time.sleep(self.args.node_fetch_interval)
+
+        cost_time = time.time() - start_time
+
+        manager.deinit_pipeline_module()
+        # collect the profiling data
+        profiling_data = defaultdict(lambda: [0, 0])
+        image_total = 0
+        for _ in range(module_size):
+            msg_info = msg_queue.get()
+            profiling_data[msg_info.module_name][0] += msg_info.process_cost_time
+            profiling_data[msg_info.module_name][1] += msg_info.send_cost_time
+            if msg_info.module_name != -1:
+                image_total = msg_info.image_total
+        if image_total > 0:
+            self.profiling(profiling_data, image_total)
+        perf_info = (
+            f"Number of images: {image_total}, "
+            f"total cost {cost_time:.2f}s, FPS: "
+            f"{safe_div(image_total, cost_time):.2f}"
+        )
+        print(perf_info)
+        log.info(perf_info)
+
+        msg_queue.close()
+        msg_queue.join_thread()
+
+    def profiling(self, profiling_data, image_total):
+        e2e_cost_time_per_image = 0
+        for module_name in profiling_data:
+            data = profiling_data[module_name]
+            total_time = data[0]
+            process_time = data[0] - data[1]
+            send_time = data[1]
+            process_avg = safe_div(process_time * 1000, image_total)
+            e2e_cost_time_per_image += process_avg
+            log.info(
+                f"{module_name} cost total {total_time:.2f} s, process avg cost {process_avg:.2f} ms, "
+                f"send waiting time avg cost {safe_div(send_time * 1000, image_total):.2f} ms"
+            )
+        log.info("----------------------------------------------------")
+        log.info(f"e2e cost time per image {e2e_cost_time_per_image:.2f} ms")
+
+    def __del__(self):
+        if hasattr(self, "process") and self.process:
+            self.process.close()
diff --git a/pipeline/infer.py b/pipeline/infer.py
new file mode 100644
index 000000000..029d0951c
--- /dev/null
+++ b/pipeline/infer.py
@@ -0,0 +1,20 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../")))
+
+from pipeline import infer_args  # noqa
+from pipeline.parallel_pipeline import ParallelPipeline  # noqa
+
+
+def main():
+    args = infer_args.get_args()
+    parallel_pipeline = ParallelPipeline(args)
+    parallel_pipeline.start_pipeline()
+    parallel_pipeline.infer_for_images(args.input_images_dir, task_id=0)
+    parallel_pipeline.stop_pipeline()
+
+
+if __name__ == "__main__":
+    main()
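
For reference, the same flow can be driven programmatically instead of through `pipeline/infer.py`. The sketch below polls `fetch_result()`, which is non-blocking and returns `None` while the result queue is empty; the single-result stop condition is a simplifying assumption, not part of this patch (a real caller would track `task_images_num`).

```python
import time

from pipeline import infer_args
from pipeline.parallel_pipeline import ParallelPipeline

# assumes the script was launched with flags such as
# --input_images_dir ./imgs --rec_model_name_or_config en_ms_rec_crnn_resnet34
args = infer_args.get_args()
pipeline = ParallelPipeline(args)
pipeline.start_pipeline()
pipeline.infer_for_images(args.input_images_dir, task_id=0)

# poll until the first result arrives (simplified stop condition)
result = None
while result is None:
    result = pipeline.fetch_result()
    time.sleep(0.01)
print(result)

pipeline.stop_pipeline()
```
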
"--input_images_dir", + type=str, + required=True, + help="Image or folder path for inference", + ) + + parser.add_argument( + "--node_fetch_interval", + type=float, + default=0.001, + required=False, + help="Interval(seconds) that each node fetch data from queue.", + ) + + parser.add_argument( + "--result_contain_score", + type=bool, + default=False, + required=False, + help="If save confidence score to output result.", + ) + + parser.add_argument( + "--det_algorithm", + type=str, + default="DB++", + choices=["DB", "DB++", "DB_MV3", "DB_PPOCRv3", "PSE"], + help="detection algorithm.", + ) # determine the network architecture + parser.add_argument( + "--det_model_name_or_config", type=str, required=False, help="Detection model name or config file path." + ) + + parser.add_argument( + "--cls_algorithm", + type=str, + default="MV3", + choices=["MV3"], + help="classification algorithm.", + ) # determine the network architecture + parser.add_argument( + "--cls_model_name_or_config", type=str, required=False, help="Classification model name or config file path." + ) + + parser.add_argument( + "--rec_algorithm", + type=str, + default="CRNN", + choices=["CRNN", "RARE", "CRNN_CH", "RARE_CH", "SVTR", "SVTR_PPOCRv3_CH"], + help="recognition algorithm", + ) + parser.add_argument( + "--rec_model_name_or_config", type=str, required=False, help="Recognition model name or config file path." + ) + + parser.add_argument( + "--character_dict_path", type=str, required=False, help="Character dict file path for recognition models." + ) + + # ZHQ TODO + parser.add_argument( + "--layout_model_name_or_config", type=str, required=False, help="Layout model name or config file path." + ) + + parser.add_argument( + "--res_save_dir", + type=str, + default="inference_results", + required=False, + help="Saving dir for inference results.", + ) + + parser.add_argument( + "--input_array_save_dir", + type=str, + required=False, + help="Saving input array.", + ) + + parser.add_argument( + "--vis_det_save_dir", type=str, required=False, help="Saving dir for visualization of detection results." + ) + parser.add_argument( + "--vis_pipeline_save_dir", + type=str, + required=False, + help="Saving dir for visualization of det+cls(optional)+rec pipeline inference results.", + ) + parser.add_argument( + "--vis_font_path", + type=str, + default="docs/fonts/simfang.tff", + required=False, + help="Font file path for recognition model.") + parser.add_argument( + "--crop_save_dir", type=str, required=False, help="Saving dir for images cropped of detection results." + ) + parser.add_argument( + "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring." 
+def setup_logger(args):
+    """
+    initialize the log system
+    """
+    log.init_logger(args.show_log, args.save_log_dir)
+    log.save_args(args)
+
+
+def update_task_info(args):
+    """
+    add internal parameters according to the task type
+    """
+    det = bool(args.det_model_name_or_config)
+    cls = bool(args.cls_model_name_or_config)
+    rec = bool(args.rec_model_name_or_config)
+    layout = bool(args.layout_model_name_or_config)
+
+    task_map = {
+        (True, False, False, False): TaskType.DET,
+        (False, True, False, False): TaskType.CLS,
+        (False, False, True, False): TaskType.REC,
+        (True, False, True, False): TaskType.DET_REC,
+        (True, True, True, False): TaskType.DET_CLS_REC,
+        (False, False, False, True): TaskType.LAYOUT,
+    }
+
+    task_order = (det, cls, rec, layout)
+    if task_order in task_map:
+        setattr(args, "task_type", task_map[task_order])
+    else:
+        unsupported_task_map = {
+            (False, False, False, False): "empty",
+            (True, True, False, False): "det+cls",
+            (False, True, True, False): "cls+rec",
+        }
+        # fall back to the raw tuple for combinations not named above
+        raise ValueError(
+            f"Only det, cls, rec, det+rec and det+cls+rec are supported, "
+            f"but got {unsupported_task_map.get(task_order, task_order)}. Please check the model paths!"
+        )
+
+    for prefix in ("det", "cls", "rec", "layout"):
+        name_or_config = getattr(args, f"{prefix}_model_name_or_config")
+        config_path = get_config_by_name_for_model(name_or_config) if name_or_config else None
+        setattr(args, f"{prefix}_config_path", config_path)
+
+    return args
+
+
+def check_file(name, file):
+    if not os.path.exists(file):
+        raise ValueError(f"{name} must be a file, but {file} doesn't exist.")
+    if not os.path.isfile(file):
+        raise ValueError(f"{name} must be a file, but got a dir of {file}.")
+
+
+def check_positive(name, value):
+    if value < 1:
+        raise ValueError(f"{name} must be positive, but got {value}.")
+
+
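
`check_and_update_args` below opens each resolved config and validates two `predict` keys. A minimal standalone sketch of that check, assuming one of the config files touched earlier in this patch:

```python
import yaml
from addict import Dict

# hypothetical path; any config with a predict section works
with open("configs/rec/crnn/crnn_resnet34.yaml") as fp:
    cfg = Dict(yaml.safe_load(fp))

assert cfg.predict.ckpt_load_path, "predict.ckpt_load_path must point to a checkpoint file"
assert cfg.predict.loader.batch_size >= 1, "predict.loader.batch_size must be positive"
```
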
+def check_and_update_args(args):
+    """
+    check parameters
+    """
+    if not args.input_images_dir or not os.path.exists(args.input_images_dir):
+        raise ValueError("input_images_dir must be a dir containing multiple images or the path of a single image.")
+
+    if args.crop_save_dir and args.task_type not in (TaskType.DET_REC, TaskType.DET_CLS_REC):
+        raise ValueError("det_model_path and rec_model_path can't be empty when crop_save_dir is set.")
+
+    if args.vis_pipeline_save_dir and args.task_type not in (TaskType.DET_REC, TaskType.DET_CLS_REC):
+        raise ValueError("det_model_path and rec_model_path can't be empty when vis_pipeline_save_dir is set.")
+
+    if args.vis_det_save_dir and args.task_type != TaskType.DET:
+        raise ValueError(
+            "det_model_path can't be empty and cls_model_path/rec_model_path must be empty when vis_det_save_dir "
+            "is set for a single detection task."
+        )
+
+    if not args.res_save_dir:
+        raise ValueError("res_save_dir can't be empty.")
+
+    # args carries the *_config_path attributes resolved in update_task_info, not *_model_path
+    need_check_file = {
+        "det_config_path": args.det_config_path,
+        "cls_config_path": args.cls_config_path,
+        "rec_config_path": args.rec_config_path,
+    }
+    for name, path in need_check_file.items():
+        if path:
+            check_file(name, path)
+            with open(path) as fp:
+                yaml_cfg = Dict(yaml.safe_load(fp))
+                check_file(name, yaml_cfg.predict.ckpt_load_path)
+                check_positive(name, yaml_cfg.predict.loader.batch_size)
+
+    need_check_dir_not_same = {
+        "input_images_dir": args.input_images_dir,
+        "crop_save_dir": args.crop_save_dir,
+        "vis_pipeline_save_dir": args.vis_pipeline_save_dir,
+        "vis_det_save_dir": args.vis_det_save_dir,
+    }
+    for (name1, dir1), (name2, dir2) in itertools.combinations(need_check_dir_not_same.items(), 2):
+        if (dir1 and dir2) and os.path.realpath(os.path.normcase(dir1)) == os.path.realpath(os.path.normcase(dir2)):
+            raise ValueError(f"{name1} and {name2} can't be the same path.")
+
+    return args
+
+
+def init_save_dir(args):
+    if args.res_save_dir:
+        save_path_init(args.res_save_dir, exist_ok=True)
+    if args.crop_save_dir:
+        save_path_init(args.crop_save_dir)
+    if args.vis_pipeline_save_dir:
+        save_path_init(args.vis_pipeline_save_dir)
+    if args.vis_det_save_dir:
+        save_path_init(args.vis_det_save_dir)
+    if args.save_log_dir:
+        save_path_init(args.save_log_dir, exist_ok=True)
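
`parallel_pipeline.py` below feeds the pipeline in `batch_num`-sized slices, guarded by `i % batch_num == 0` inside a per-image loop. The equivalent stepped-range idiom, shown here only to clarify the slicing behavior (file names are made up):

```python
images = [f"img_{i}.jpg" for i in range(7)]
batch_num = 3

batches = [images[i : i + batch_num] for i in range(0, len(images), batch_num)]
print(batches)
# [['img_0.jpg', 'img_1.jpg', 'img_2.jpg'],
#  ['img_3.jpg', 'img_4.jpg', 'img_5.jpg'],
#  ['img_6.jpg']]
```
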
diff --git a/pipeline/parallel_pipeline.py b/pipeline/parallel_pipeline.py
new file mode 100644
index 000000000..93115e84a
--- /dev/null
+++ b/pipeline/parallel_pipeline.py
@@ -0,0 +1,98 @@
+import argparse
+import os
+import time
+import sys
+from collections import defaultdict
+from multiprocessing import Manager, Process, Queue
+
+import numpy as np
+import tqdm
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../")))
+
+from framework.pipeline_manager import ParallelPipelineManager
+from data_process.utils import cv_utils
+from tasks import TaskType
+
+
+class ParallelPipeline:
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+        self.pipeline_manager = ParallelPipelineManager(args)
+        self.input_queue = self.pipeline_manager.input_queue
+        self.infer_params = {}
+
+    def start_pipeline(self):
+        self.pipeline_manager.start_pipeline()
+
+    def stop_pipeline(self):
+        self.pipeline_manager.stop_pipeline()
+
+    def infer_for_images(self, input_images_dir, task_id=0):
+        self.infer_params = dict(**self.pipeline_manager.module_params)
+        self.send_image(input_images_dir, task_id)
+
+    def fetch_result(self):
+        return self.pipeline_manager.fetch_result()
+
+    def send_image(self, images: str, task_id=0):
+        """
+        send images to the input queue of the pipeline
+        """
+        if not (os.path.isdir(images) or os.path.isfile(images)):
+            raise ValueError("images must be an image path or a directory.")
+
+        # det, det(+cls)+rec
+        batch_num = 1
+
+        # cls, rec, layout
+        if self.args.task_type in (TaskType.CLS, TaskType.REC, TaskType.LAYOUT):
+            for name, value in self.infer_params.items():
+                if name.endswith("_batch_num"):
+                    batch_num = max(value)
+
+        self._send_batch_image(images, batch_num, task_id)
+
+    def _send_batch_image(self, images, batch_num, task_id):
+        if os.path.isdir(images):
+            show_progressbar = not self.args.show_log
+            input_image_list = [os.path.join(images, path) for path in os.listdir(images) if not path.endswith(".txt")]
+            images_num = len(input_image_list)
+            for i in (
+                tqdm.tqdm(range(images_num), desc="send image to pipeline") if show_progressbar else range(images_num)
+            ):
+                if i % batch_num == 0:
+                    batch_images = input_image_list[i : i + batch_num]
+                    self.input_queue.put((batch_images, (images_num, task_id)), block=True)
+        else:
+            self.input_queue.put([[images], (1, task_id)], block=True)
+
+    def infer_for_array(self, input_array, task_id=0):
+        self.infer_params = dict(**self.pipeline_manager.module_params)
+        self.send_array(input_array, task_id)
+
+    def send_array(self, images, task_id):
+        if isinstance(images, np.ndarray):
+            self._send_batch_array([images], 1, task_id)
+        elif isinstance(images, (tuple, list)):
+            if len(images) == 0:
+                return
+            if not cv_utils.check_type_in_container(images, np.ndarray):
+                raise ValueError("unknown input data, images should be np.ndarray, or a tuple/list of np.ndarray")
+            # cls, rec, layout
+            batch_num = 1
+            if self.args.task_type in (TaskType.CLS, TaskType.REC, TaskType.LAYOUT):
+                for name, value in self.infer_params.items():
+                    if name.endswith("_batch_num"):
+                        batch_num = max(value)
+            self._send_batch_array(images, batch_num, task_id)
+        else:
+            raise ValueError(f"unknown input data: {type(images)}")
+
+    def _send_batch_array(self, images, batch_num, task_id):
+        show_progressbar = not self.args.show_log
+        images_num = len(images)
+        for i in tqdm.tqdm(range(images_num), desc="send image to pipeline") if show_progressbar else range(images_num):
+            if i % batch_num == 0:
+                batch_images = images[i : i + batch_num]
+                self.input_queue.put([batch_images, (images_num, task_id)], block=True)
diff --git a/pipeline/tasks/__init__.py b/pipeline/tasks/__init__.py
new file mode 100644
index 000000000..b2de6fbc4
--- /dev/null
+++ b/pipeline/tasks/__init__.py
@@ -0,0 +1,20 @@
+from enum import Enum
+
+
+class TaskType(Enum):
+    DET = 0  # Detection Model
+    CLS = 1  # Classification Model
+    REC = 2  # Recognition Model
+    DET_REC = 3  # Detection and Recognition Model
+    DET_CLS_REC = 4  # Detection, Classification and Recognition Model
+    LAYOUT = 5  # Layout Model
+
+
+SUPPORTED_TASK_BASIC_MODULE = {
+    TaskType.DET: [TaskType.DET],
+    TaskType.CLS: [TaskType.CLS],
+    TaskType.REC: [TaskType.REC],
+    TaskType.DET_REC: [TaskType.DET_REC],
+    TaskType.DET_CLS_REC: [TaskType.DET_CLS_REC],
+    TaskType.LAYOUT: [TaskType.LAYOUT],
+}
diff --git a/pipeline/utils/__init__.py b/pipeline/utils/__init__.py
new file mode 100644
index 000000000..e04b55d61
--- /dev/null
+++ b/pipeline/utils/__init__.py
@@ -0,0 +1,12 @@
+from .adapted import get_config_by_name_for_model
+from .logger import logger_instance as log
+from .safe_utils import (
+    check_valid_dir,
+    check_valid_file,
+    file_base_check,
+    safe_div,
+    safe_list_writer,
+    save_path_init,
+    suppress_stderr,
+    suppress_stdout,
+)
diff --git a/pipeline/utils/adapted/__init__.py b/pipeline/utils/adapted/__init__.py
new file mode 100644
index 000000000..bbb5a5ea6
--- /dev/null
+++ b/pipeline/utils/adapted/__init__.py
@@ -0,0 +1,41 @@
+import os
+
+import yaml
+
+from .mindocr_models import MINDOCR_CONFIG_PATH, MINDOCR_MODELS
+from .mmocr_models import MMOCR_CONFIG_PATH, MMOCR_MODELS
+from .paddleocr_models import PADDLEOCR_CONFIG_PATH, PADDLEOCR_MODELS
+
+__all__ = ["get_config_by_name_for_model"]
+
+
+def get_config_by_name_for_model(model_name_or_config: str):
+    if os.path.isfile(model_name_or_config):
+        filename = model_name_or_config
+    elif model_name_or_config in MINDOCR_MODELS:
+        filename = os.path.abspath(os.path.join(MINDOCR_CONFIG_PATH, MINDOCR_MODELS[model_name_or_config]))
+    elif model_name_or_config in PADDLEOCR_MODELS:
+        filename = os.path.abspath(os.path.join(PADDLEOCR_CONFIG_PATH, 
PADDLEOCR_MODELS[model_name_or_config])) + elif model_name_or_config in MMOCR_MODELS: + filename = os.path.abspath(os.path.join(MMOCR_CONFIG_PATH, MMOCR_MODELS[model_name_or_config])) + else: + raise ValueError( + f"The {model_name_or_config} must be a model name or YAML config file path, " + "please check whether the file exists, or whether model name is in the supported models list." + ) + + with open(filename) as fp: + cfg = yaml.safe_load(fp) + + try: + cfg["eval"]["dataset"]["transform_pipeline"] + cfg["postprocess"] + except KeyError: + preprocess_desc = "{eval: {dataset: {transform_pipeline: ...}}}" + postprocess_desc = "{postprocess: ...}" + raise ValueError( + f"The YAML config file {filename} must contain preprocess pipeline key {preprocess_desc} " + f"and postprocess key {postprocess_desc}." + ) + + return filename diff --git a/pipeline/utils/adapted/mindocr_models.py b/pipeline/utils/adapted/mindocr_models.py new file mode 100644 index 000000000..2858d7435 --- /dev/null +++ b/pipeline/utils/adapted/mindocr_models.py @@ -0,0 +1,16 @@ +import os + +MINDOCR_CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../../configs")) + +MINDOCR_MODELS = { + "en_ms_det_dbnet_resnet50": "det/dbnet/db_r50_icdar15.yaml", + "en_ms_det_dbnetpp_resnet50": "det/dbnet/dbpp_r50_icdar15.yaml", + "en_ms_det_psenet_resnet152": "det/psenet/pse_r152_icdar15.yaml", + "en_ms_det_psenet_resnet50": "det/psenet/pse_r50_icdar15.yaml", + "en_ms_det_psenet_mobilenetv3": "det/psenet/pse_mv3_icdar15.yaml", + "ch_ms_det_psenet_resnet152": "det/psenet/pse_r152_ctw1500.yaml", + "en_ms_rec_crnn_resnet34": "rec/crnn/crnn_resnet34.yaml", + "en_ms_det_east_resnet50": "det/east/east_r50_icdar15.yaml", + "en_ms_det_east_mobilenetv3": "det/east/east_mobilenetv3_icdar15.yaml", + "en_ms_rec_visionlan_resnet45": "rec/visionlan/visionlan_resnet45_LA.yaml", +} diff --git a/pipeline/utils/adapted/mmocr_models.py b/pipeline/utils/adapted/mmocr_models.py new file mode 100644 index 000000000..2970390a6 --- /dev/null +++ b/pipeline/utils/adapted/mmocr_models.py @@ -0,0 +1,13 @@ +import os + +MMOCR_CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../configs")) + +# fmt: off +MMOCR_MODELS = { + "en_mm_det_dbnetpp_resnet50": "det/mmocr/dbnetpp_resnet50_fpnc_1200e_icdar2015.yaml", # dbnet++ resnet50 + "en_mm_det_fcenet_resnet50": "det/mmocr/fcenet_resnet50_fpn_1500e_icdar2015.yaml", # fcenet resnet50 + "en_mm_rec_nrtr_resnet31": "rec/mmocr/nrtr_resnet31-1by8-1by4_6e_st_mj.yaml", # nrtr resnet31 + "en_mm_rec_satrn_shallowcnn": "rec/mmocr/satrn_shallow_5e_st_mj.yaml", # satrn shallow + +} +# fmt: on diff --git a/pipeline/utils/adapted/paddleocr_models.py b/pipeline/utils/adapted/paddleocr_models.py new file mode 100644 index 000000000..a1a43cda5 --- /dev/null +++ b/pipeline/utils/adapted/paddleocr_models.py @@ -0,0 +1,42 @@ +import os + +PADDLEOCR_CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../configs")) + +# fmt: off +PADDLEOCR_MODELS = { + "ch_pp_det_OCRv4": "det/ppocr/ch_PP-OCRv4_det_cml.yaml", # ch_PP-OCRv4_det + "ch_pp_server_det_v2.0": "det/ppocr/ch_det_res18_db_v2.0.yaml", # ch_ppocr_server_v2.0_det + "ch_pp_det_OCRv3": "det/ppocr/ch_PP-OCRv3_det_cml.yaml", # ch_PP-OCRv3_det + "ch_pp_server_rec_v2.0": "rec/ppocr/rec_chinese_common_v2.0.yaml", # ch_ppocr_server_v2.0_rec + "ch_pp_rec_OCRv3": "rec/ppocr/ch_PP-OCRv3_rec_distillation.yaml", # ch_PP-OCRv3_rec + "ch_pp_rec_OCRv4": "rec/ppocr/ch_PP-OCRv4_rec_distillation.yaml", # ch_PP-OCRv4_rec + 
"ch_pp_mobile_cls_v2.0": "cls/ppocr/cls_mv3.yaml", # ch_ppocr_mobile_v2.0_cls + "ch_pp_det_OCRv2": "det/ppocr/ch_PP-OCRv2_det_cml.yaml", # ch_PP-OCRv2_det + "ch_pp_mobile_det_v2.0_slim": "det/ppocr/ch_det_mv3_db_v2.0.yaml", # ch_ppocr_mobile_slim_v2.0_det + "ch_pp_mobile_det_v2.0": "det/ppocr/ch_det_mv3_db_v2.0.yaml", # ch_ppocr_mobile_v2.0_det + "en_pp_det_OCRv3": "det/ppocr/ch_PP-OCRv3_det_cml.yaml", # en_PP-OCRv3_det + "ml_pp_det_OCRv3": "det/ppocr/ch_PP-OCRv3_det_cml.yaml", # ml_PP-OCRv3_det + "ch_pp_rec_OCRv2": "rec/ppocr/ch_PP-OCRv2_rec_distillation.yaml", # ch_PP-OCRv2_rec + "ch_pp_mobile_rec_v2.0": "rec/ppocr/rec_chinese_lite_v2.0.yaml", # ch_ppocr_mobile_v2.0_rec + "en_pp_rec_OCRv3": "rec/ppocr/en_PP-OCRv3_rec.yaml", # en_PP-OCRv3_rec + "en_pp_mobile_rec_number_v2.0_slim": "rec/ppocr/rec_en_number_lite.yaml", # en_number_mobile_slim_v2.0_rec + "en_pp_mobile_rec_number_v2.0": "rec/ppocr/rec_en_number_lite.yaml", # en_number_mobile_v2.0_rec + "korean_pp_rec_OCRv3": "rec/ppocr/korean_PP-OCRv3_rec.yaml", # korean_PP-OCRv3_rec + "japan_pp_rec_OCRv3": "rec/ppocr/japan_PP-OCRv3_rec.yaml", # japan_PP-OCRv3_rec + "chinese_cht_pp_rec_OCRv3": "rec/ppocr/chinese_cht_PP-OCRv3_rec.yaml", # chinese_cht_PP-OCRv3_rec + "te_pp_rec_OCRv3": "rec/ppocr/te_PP-OCRv3_rec.yaml", # te_PP-OCRv3_rec + "ka_pp_rec_OCRv3": "rec/ppocr/ka_PP-OCRv3_rec.yaml", # ka_PP-OCRv3_rec + "ta_pp_rec_OCRv3": "rec/ppocr/ta_PP-OCRv3_rec.yaml", # ta_PP-OCRv3_rec + "latin_pp_rec_OCRv3": "rec/ppocr/latin_PP-OCRv3_rec.yaml", # latin_PP-OCRv3_rec + "arabic_pp_rec_OCRv3": "rec/ppocr/arabic_PP-OCRv3_rec.yaml", # arabic_PP-OCRv3_rec + "cyrillic_pp_rec_OCRv3": "rec/ppocr/cyrillic_PP-OCRv3_rec.yaml", # cyrillic_PP-OCRv3_rec + "devanagari_pp_rec_OCRv3": "rec/ppocr/devanagari_PP-OCRv3_rec.yaml", # devanagari_PP-OCRv3_rec + "en_pp_det_psenet_resnet50vd": "det/ppocr/det_r50_vd_pse.yaml", # pse_resnet50_vd + "en_pp_det_dbnet_resnet50vd": "det/ppocr/det_r50_vd_db.yaml", # dbnet resnet50_vd + "en_pp_det_east_resnet50vd": "det/ppocr/det_r50_vd_east.yaml", # east resnet50_vd + "en_pp_det_sast_resnet50vd": "det/ppocr/det_r50_vd_sast_icdar15.yaml", # sast resnet50_vd + "en_pp_rec_crnn_resnet34vd": "rec/ppocr/rec_r34_vd_none_bilstm_ctc.yaml", # crnn resnet34_vd + "en_pp_rec_rosetta_resnet34vd": "rec/ppocr/rec_r34_vd_none_none_ctc.yaml", # en_pp_rec_rosetta_resnet34vd + "en_pp_rec_vitstr_vitstr": "rec/ppocr/rec_vitstr_none_ce.yaml", # vitstr +} +# fmt: on diff --git a/pipeline/utils/logger.py b/pipeline/utils/logger.py new file mode 100644 index 000000000..bd95fecba --- /dev/null +++ b/pipeline/utils/logger.py @@ -0,0 +1,221 @@ +import argparse +import logging +import os +import sys +import threading +import time +from logging.handlers import RotatingFileHandler + +# Log level name and number mapping +_name_to_log_level = { + "ERROR": 40, + "WARNING": 30, + "INFO": 20, + "DEBUG": 10, +} + +# mindspore level and level name +_ms_level_to_name = { + "3": "ERROR", + "2": "WARNING", + "1": "INFO", + "0": "DEBUG", +} + +MAX_BYTES = 100 * 1024 * 1024 +BACKUP_COUNT = 10 +LOG_TYPE = "mindocr" +LOG_ENV = "MINDOCR_LOG_LEVEL" +INFER_INSTALL_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) + "/" + + +class DataFormatter(logging.Formatter): + """Log formatter""" + + def __init__(self, sub_module, fmt=None, **kwargs): + """ + Initialization of logFormatter. + :param sub_module: The submodule name. + :param fmt: Specified format pattern. Default: None. 
+ :param kwargs: None + """ + super(DataFormatter, self).__init__(fmt=fmt, **kwargs) + self.sub_module = sub_module.upper() + + def formatTime(self, record, datefmt=None): + """ + Override formatTime for uniform format %Y-%m-%d-%H:%M:%S.SSS.SSS + :param record: Log record + :param datefmt: Date format + :return: formatted timestamp + """ + create_time = self.converter(record.created) + if datefmt: + return time.strftime(datefmt, create_time) + + timestamp = time.strftime("%Y-%m-%d-%H:%M:%S", create_time) + record_msecs = str(round(record.msecs * 1000)) + # Format the time stamp + return f"{timestamp}.{record_msecs[:3]}.{record_msecs[3:]}" + + def format(self, record): + """ + Apply log format with specified pattern. + :param record: Format pattern. + :return: formatted log content according to format pattern. + """ + if record.pathname.startswith(INFER_INSTALL_PATH): + # Get the relative path + record.filepath = record.pathname[len(INFER_INSTALL_PATH) :] + elif "/" in record.pathname: + record.filepath = record.pathname.strip().split("/")[-1] + else: + record.filepath = record.pathname + record.sub_module = self.sub_module + return super().format(record) + + +class RotatingLogFileHandler(RotatingFileHandler): + def _open(self): + return os.fdopen(os.open(self.baseFilename, os.O_RDWR | os.O_CREAT, 0o600), "a") + + +def _filter_env_level(): + log_env_level = os.getenv(LOG_ENV, "1") + if ( + not isinstance(log_env_level, str) + or not log_env_level.isdigit() + or int(log_env_level) < 0 + or int(log_env_level) > 3 + ): + log_env_level = "1" + return log_env_level + + +class LOGGER(logging.Logger): + def __init__(self, logger_name, log_level=logging.WARNING): + super(LOGGER, self).__init__(logger_name) + self.model_name = logger_name + self.data_formatter = DataFormatter(self.model_name, self._get_formatter()) + self.console_log_level = ( + _name_to_log_level.get(_ms_level_to_name.get(_filter_env_level())) if log_level is None else log_level + ) + console = logging.StreamHandler(sys.stdout) + console.setLevel(level=self.console_log_level) + console.setFormatter(self.data_formatter) + self.addHandler(console) + + @staticmethod + def _get_formatter(): + """ + + :return: str, the string of log formatter. 
+ """ + formatter = ( + "[%(levelname)s] %(sub_module)s(%(process)d:" + "%(thread)d,%(processName)s):%(asctime)s " + "[%(filepath)s:%(lineno)d] %(message)s" + ) + return formatter + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and os.getenv("RANK_ID", "0") == "0": + self._log(logging.INFO, msg, args, **kwargs) + + def debug(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.DEBUG) and os.getenv("RANK_ID", "0") == "0": + self._log(logging.DEBUG, msg, args, **kwargs) + + def warning(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.WARNING) and os.getenv("RANK_ID", "0") == "0": + self._log(logging.WARNING, msg, args, **kwargs) + + def error(self, msg, *args, **kwargs): + rank_id = os.getenv("RANK_ID", None) + if rank_id and rank_id.isdigit() and 0 <= int(rank_id) < 8: + msg = f"[The error from this card id ({rank_id})] " + msg + if self.isEnabledFor(logging.ERROR): + self._log(logging.ERROR, msg, args, **kwargs) + + def setup_logging_file(self, log_dir, max_size=100 * 1024 * 1024, backup_cnt=10): + """Setup logging file.""" + if max_size > 1024 * 1024 * 1024 or max_size < 0: + logging.error("single log file size should more than 0, less than or equal to 1G.") + raise Exception("single log file size should more than 0, less than or equal to 1G.") + if backup_cnt > 100 or backup_cnt < 0: + logging.error("log file backup count should more than 0, less than or equal to 100") + raise Exception("log file backup count should more than 0, less than or equal to 100") + log_dir = os.path.realpath(log_dir) + if not os.path.exists(log_dir): + os.makedirs(log_dir, mode=0o750) + log_file_name = f"{self.model_name}.log" + log_fn = os.path.join(log_dir, log_file_name) + fh = RotatingLogFileHandler(log_fn, "a", max_size, backup_cnt) + fh.setFormatter(self.data_formatter) + fh.setLevel(logging.INFO) + self.addHandler(fh) + + def filter_log_str(self, msg) -> str: + def _check_str(need_check_str): + if len(need_check_str) > 10000: + self.warning("Input should be <= 10000") + return False + filter_strs = ["\r", "\n", "\\r", "\\n"] + for filter_str in filter_strs: + if filter_str in need_check_str: + self.warning("Input should not be included \\r or \\n") + return False + return True + + if isinstance(msg, str) and not _check_str(msg): + return "" + else: + return msg + + def save_args(self, args): + """ + :param args: input args param, just support argparse or dict + :return: None + """ + if isinstance(args, argparse.Namespace): + args = vars(args) + elif isinstance(args, dict): + pass + else: + logging.error("This api just support argparse or dict, please check your input type.") + raise Exception("This api just support argparse or dict, please check your input type.") + self.info("Args:") + args_copy = args.copy() + for key, value in args_copy.items(): + self.info("--> %s: %s", key, self.filter_log_str(args_copy[key])) + self.info("Finish read param") + + +class SingletonType(type): + _instance_lock = threading.Lock() + + def __call__(cls, *args, **kwargs): + if not hasattr(cls, "_instance"): + with SingletonType._instance_lock: + if not hasattr(cls, "_instance"): + cls._instance = super(SingletonType, cls).__call__(*args, **kwargs) + return cls._instance + + +class LoggerSystem(metaclass=SingletonType): + def __init__(self, model_name=LOG_TYPE, max_size=MAX_BYTES, backup_cnt=BACKUP_COUNT): + self.model_name = model_name + self.max_bytes = max_size + self.backup_count = backup_cnt + self.logger = None + + def init_logger(self, show_info_log=False, 
+class LoggerSystem(metaclass=SingletonType):
+    def __init__(self, model_name=LOG_TYPE, max_size=MAX_BYTES, backup_cnt=BACKUP_COUNT):
+        self.model_name = model_name
+        self.max_bytes = max_size
+        self.backup_count = backup_cnt
+        self.logger = None
+
+    def init_logger(self, show_info_log=False, save_path=None):
+        self.logger = LOGGER(self.model_name, logging.INFO if show_info_log else logging.WARNING)
+        if save_path:
+            self.logger.setup_logging_file(save_path, self.max_bytes, self.backup_count)
+
+    def __getattr__(self, item):
+        return object.__getattribute__(self.logger, item)
+
+
+logger_instance = LoggerSystem(LOG_TYPE)
diff --git a/pipeline/utils/safe_utils.py b/pipeline/utils/safe_utils.py
new file mode 100644
index 000000000..a29760dec
--- /dev/null
+++ b/pipeline/utils/safe_utils.py
@@ -0,0 +1,151 @@
+import contextlib
+import json
+import os
+import re
+import shutil
+import stat
+
+from .logger import logger_instance as log
+
+
+def safe_list_writer(save_dict, save_path):
+    """
+    Write the inference results to a file, one line per image.
+    :param save_dict: mapping from image file name to its inference result
+    :param save_path: path of the output file
+    :return: None
+    """
+    flags, modes = os.O_WRONLY | os.O_CREAT | os.O_APPEND, stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP
+    with os.fdopen(os.open(save_path, flags, modes), "w") as f:
+        if not save_dict:
+            f.write("")
+        for filename, res in save_dict.items():
+            content = os.path.basename(filename) + "\t" + json.dumps(res, ensure_ascii=False) + "\n"
+            f.write(content)
+
+
+def safe_div(dividend, divisor):
+    try:
+        quotient = dividend / divisor
+    except ZeroDivisionError as error:
+        log.error(error)
+        quotient = 0
+    return quotient
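
For context, `safe_list_writer` above produces the `<basename>\t<json>` lines that `visual_c_results.py` later parses. A small example with made-up results (the output directory is assumed to exist):

```python
results = {
    "/data/imgs/word_1.png": [
        {"transcription": "JOINT", "points": [[8, 3], [91, 3], [91, 36], [8, 36]]},
    ],
}
safe_list_writer(results, "./inference_results/pipeline_results.txt")
# written line:
# word_1.png	[{"transcription": "JOINT", "points": [[8, 3], [91, 3], [91, 36], [8, 36]]}]
```
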
+
+
+def verify_file_size(file_path) -> bool:
+    conf_file_size = os.path.getsize(file_path)
+    if conf_file_size > 0 and conf_file_size / 1024 / 1024 < 10:
+        return True
+    return False
+
+
+def valid_characters(pattern: str, characters: str) -> bool:
+    if re.match(r".*[\s]+", characters):
+        return False
+    if not re.match(pattern, characters):
+        return False
+    return True
+
+
+def file_base_check(file_path: str) -> None:
+    base_name = os.path.basename(file_path)
+    if not file_path or not os.path.isfile(file_path):
+        raise FileNotFoundError(f"the file:{base_name} does not exist!")
+    if not valid_characters("^[A-Za-z0-9_+-/]+$", file_path):
+        raise Exception(f"file path:{os.path.relpath(file_path)} should only include characters 'A-Za-z0-9+-_/'!")
+    if not verify_file_size(file_path):
+        raise Exception(f"{base_name}: the file size must be between [1, 10M]!")
+    if os.path.islink(file_path):
+        raise Exception(f"the file:{base_name} is a symbolic link, which is not allowed!")
+    if not os.access(file_path, mode=os.R_OK):
+        raise FileNotFoundError(f"the file:{base_name} is unreadable!")
+
+
+def get_safe_name(path):
+    """Remove trailing path separators before retrieving the basename,
+    e.g. /xxx/ -> /xxx
+    """
+    return os.path.basename(os.path.abspath(path))
+
+
+def custom_islink(path):
+    """Remove trailing path separators before checking soft links,
+    e.g. /xxx/ -> /xxx
+    """
+    return os.path.islink(os.path.abspath(path))
+
+
+def check_valid_dir(path):
+    name = get_safe_name(path)
+    check_valid_path(path, name)
+    if not os.path.isdir(path):
+        log.error(f"Please check if {name} is a directory.")
+        raise NotADirectoryError("Check dir failed.")
+
+
+def check_valid_path(path, name):
+    if not path or not os.path.exists(path):
+        raise FileNotFoundError(f"Error! {name} must exist!")
+    if custom_islink(path):
+        raise ValueError(f"Error! {name} cannot be a soft link!")
+    if not os.access(path, mode=os.R_OK):
+        raise RuntimeError(f"Error! Please check if {name} is readable.")
+
+
+def check_valid_file(path, num_gb_limit=10):
+    filename = get_safe_name(path)
+    check_valid_path(path, filename)
+    if not os.path.isfile(path):
+        log.error(f"Please check if {filename} is a file.")
+        raise ValueError("Check file failed.")
+    check_size(path, filename, num_gb_limit=num_gb_limit)
+
+
+def check_size(path, name, num_gb_limit):
+    limit = num_gb_limit * 1024 * 1024 * 1024
+    size = os.path.getsize(path)
+    if size == 0:
+        raise ValueError(f"{name} cannot be an empty file!")
+    if size >= limit:
+        raise ValueError(f"The size of {name} must be smaller than {num_gb_limit} GB!")
+
+
+def save_path_init(path, exist_ok=False):
+    if os.path.exists(path):
+        if exist_ok:
+            return
+        shutil.rmtree(path)
+    os.makedirs(path, 0o750)
+
+
+@contextlib.contextmanager
+def suppress_stdout():
+    """
+    A context manager for doing a "deep suppression" of stdout.
+    """
+    null_fds = os.open(os.devnull, os.O_RDWR)
+    save_fds = os.dup(1)
+    os.dup2(null_fds, 1)
+
+    yield
+
+    os.dup2(save_fds, 1)
+    os.close(null_fds)
+
+
+@contextlib.contextmanager
+def suppress_stderr():
+    """
+    A context manager for doing a "deep suppression" of stderr.
+    """
+    null_fds = os.open(os.devnull, os.O_RDWR)
+    save_fds = os.dup(2)
+    os.dup2(null_fds, 2)
+
+    yield
+
+    os.dup2(save_fds, 2)
+    os.close(null_fds)
diff --git a/pipeline/utils/visual_c_results.py b/pipeline/utils/visual_c_results.py
new file mode 100644
index 000000000..2b9ac0f3e
--- /dev/null
+++ b/pipeline/utils/visual_c_results.py
@@ -0,0 +1,49 @@
+import argparse
+import json
+import os
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+from visual_utils import vis_bbox_text
+
+
+def img_write(path: str, img: np.ndarray):
+    filename = os.path.abspath(path)
+    cv2.imencode(os.path.splitext(filename)[1], img)[1].tofile(filename)
+
+
+def vis_results(prediction_result, vis_pipeline_save_dir, img_folder):
+    img_files = os.listdir(img_folder)
+    img_dict = {}
+    font_path = os.path.abspath("../../../../docs/fonts/simfang.ttf")
+    for img_name in img_files:
+        img = cv2.imread(os.path.join(img_folder, img_name))  # BGR format
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img_dict[img_name] = img
+
+    for each_pred in tqdm(prediction_result):
+        file_name, prediction = each_pred.split("\t")
+        basename = os.path.basename(file_name)
+        save_file = os.path.join(vis_pipeline_save_dir, os.path.splitext(basename)[0])
+        # result lines are written with json.dumps, so json.loads parses them safely
+        prediction = json.loads(prediction)
+        box_list = [np.array(x["points"]).reshape(-1, 2) for x in prediction]
+        text_list = [x["transcription"] for x in prediction]
+        box_text = vis_bbox_text(img_dict[file_name], box_list, text_list, font_path=font_path)
+        img_write(save_file + ".jpg", box_text)
+
+
+def read_prediction(prediction_folder):
+    with open(prediction_folder, "r", encoding="utf-8") as f:
+        prediction = f.readlines()
+    return prediction
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--img_folder", required=True, type=str)
+    parser.add_argument("--pred_dir", required=True, type=str)
+    parser.add_argument("--vis_dir", required=True, type=str)
+    args = parser.parse_args()
+
+    prediction = read_prediction(args.pred_dir)
+    vis_results(prediction, args.vis_dir, args.img_folder)
diff --git a/pipeline/utils/visual_utils.py b/pipeline/utils/visual_utils.py
new file mode 100644
index 000000000..bccc59f4b
--- /dev/null
+++ b/pipeline/utils/visual_utils.py
@@ -0,0 +1,129 @@
+"""
+OCR visualization methods
+"""
+import math
+import os.path
+import random
+
+import cv2
+import numpy as np
+from PIL 
import Image, ImageDraw, ImageFont + +__all__ = ["vis_bbox", "vis_bbox_text", "vis_crop"] + + +def vis_bbox(image, box_list, color, thickness): + """ + Draw a bounding box on an image. + :param image: input image + :param box_list: box list to add on image + :param color: color of the box + :param thickness: line thickness + :return: image with box + """ + + image = image.copy() + for box in box_list: + box = box.astype(int) + cv2.polylines(image, [box], True, color, thickness) + return image + + +def vis_bbox_text(image, box_list, text_list, font_path): + """ + Draw a bounding box and text on an image. + :param image: input image + :param box_list: box list to add on image + :param text_list: text list to add on image + :param font_path: path to font file + :return: image with box and text + """ + if font_path is None: + _font_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../docs/fonts/simfang.ttf")) + if os.path.isfile(_font_path): + font_path = _font_path + + image = Image.fromarray(image) + h, w = image.height, image.width + img_left = image.copy() + img_right = np.ones((h, w, 3), dtype=np.uint8) * 255 + random.seed(0) + + draw_left = ImageDraw.Draw(img_left) + if text_list is None or len(text_list) != len(box_list): + text_list = [None] * len(box_list) + for idx, (box, txt) in enumerate(zip(box_list, text_list)): + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + draw_left.polygon(box.astype(np.float32), fill=color) + img_right_text = draw_box_txt_fine((w, h), box, txt, font_path) + pts = np.array(box, np.int32).reshape((-1, 1, 2)) + cv2.polylines(img_right_text, [pts], True, color, 1) + img_right = cv2.bitwise_and(img_right, img_right_text) + img_left = Image.blend(image, img_left, 0.5) + img_show = Image.new(mode="RGB", size=(w * 2, h), color=(255, 255, 255)) # RGB or BGR doesn't matter + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h)) + return np.array(img_show) # RGB or BGR is the same as input image + + +def draw_box_txt_fine(img_size, box, txt, font_path): + box_height = int(math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)) + box_width = int(math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)) + if box_height > 2 * box_width and box_height > 30: + img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255)) # RGB or BGR doesn't matter + draw_text = ImageDraw.Draw(img_text) + if txt: + font = create_font(txt, (box_height, box_width), font_path) + draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font) + img_text = img_text.transpose(Image.ROTATE_270) + else: + img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255)) # RGB or BGR doesn't matter + draw_text = ImageDraw.Draw(img_text) + if txt: + font = create_font(txt, (box_width, box_height), font_path) + draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font) + pts1 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]) + pts2 = np.array(box, dtype=np.float32) + M = cv2.getPerspectiveTransform(pts1, pts2) + + img_text = np.array(img_text, dtype=np.uint8) + img_right_text = cv2.warpPerspective( + img_text, M, img_size, flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255) + ) + return img_right_text # RGB or BGR is the same as input image + + +def create_font(txt, sz, font_path): + font_size = int(sz[1] * 0.99) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + length = 
font.getlength(txt)  # getsize was removed in Pillow 10; getlength is available since Pillow 8
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font
+
+
+def vis_crop(image, box_list):
+    """
+    Generate cropped images
+    :param image: input image
+    :param box_list: list of boxes
+    :return: list of cropped images
+    """
+    image_crop = []
+    for box in box_list:
+        if box.shape != (4, 2):
+            raise ValueError("shape of crop box must be 4*2")
+        box = box.astype(np.float32)
+        img_crop_width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
+        img_crop_height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
+        pts_std = np.float32([[0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height]])
+        m = cv2.getPerspectiveTransform(box, pts_std)
+        dst_img = cv2.warpPerspective(
+            image, m, (img_crop_width, img_crop_height), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC
+        )
+        dst_img_height, dst_img_width = dst_img.shape[0:2]
+        if dst_img_width != 0 and dst_img_height / dst_img_width >= 1.5:
+            dst_img = np.rot90(dst_img)
+        image_crop.append(dst_img)
+    return image_crop
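
Taken together, the visualization helpers can be exercised end-to-end as below. The image path, box and text are made-up placeholders, the module path assumes the file layout introduced in this patch, and the font path assumes the repo's `docs/fonts/simfang.ttf` is available:

```python
import cv2
import numpy as np

from pipeline.utils.visual_utils import vis_bbox_text, vis_crop

image = cv2.cvtColor(cv2.imread("demo.jpg"), cv2.COLOR_BGR2RGB)
boxes = [np.array([[10, 10], [120, 10], [120, 50], [10, 50]])]
texts = ["hello"]

crops = vis_crop(image, boxes)  # perspective-rectified text patches
canvas = vis_bbox_text(image, boxes, texts, "docs/fonts/simfang.ttf")  # boxes left, rendered text right
cv2.imwrite("vis.jpg", cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))
```
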