-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
47 changed files
with
2,954 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .cls_infer_node import ClsInferNode | ||
from .cls_post_node import ClsPostNode | ||
from .cls_pre_node import ClsPreNode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import logging | ||
import os | ||
import time | ||
import sys | ||
import numpy as np | ||
import yaml | ||
from addict import Dict | ||
from typing import List | ||
|
||
__dir__ = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) | ||
|
||
from tools.infer.text.utils import get_ckpt_file | ||
from mindocr.data.transforms import create_transforms, run_transforms | ||
from mindocr.postprocess import build_postprocess | ||
from mindocr.infer.utils.model import MSModel, LiteModel | ||
|
||
|
||
algo_to_model_name = { | ||
"MV3": "cls_mobilenet_v3_small_100_model", | ||
} | ||
logger = logging.getLogger("mindocr") | ||
|
||
class ClsPreprocess(object): | ||
def __init__(self, args): | ||
self.args = args | ||
with open(args.cls_model_name_or_config, "r") as f: | ||
self.yaml_cfg = Dict(yaml.safe_load(f)) | ||
self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline) | ||
|
||
def __call__(self, img): | ||
data = {"image": img} | ||
# ZHQ TODO: [1:] ??? | ||
data = run_transforms(img, self.transforms[1:]) | ||
return data | ||
|
||
|
||
class ClsModelMS(MSModel): | ||
def __init__(self, args): | ||
self.args = args | ||
self.model_name = algo_to_model_name[args.cls_algorithm] | ||
self.config_path = args.cls_config_path | ||
self._init_model(self.model_name, self.config_path) | ||
|
||
|
||
class ClsModelLite(LiteModel): | ||
def __init__(self, args): | ||
self.args = args | ||
self.model_name = algo_to_model_name[args.cls_algorithm] | ||
self.config_path = args.cls_config_path | ||
self._init_model(self.model_name, self.config_path) | ||
|
||
INFER_CLS_MAP = {"MindSporeLite": ClsModelLite, "MindSpore": ClsModelMS} | ||
|
||
class ClsPostProcess(object): | ||
def __init__(self, args): | ||
self.args = args | ||
with open(args.cls_model_name_or_config, "r") as f: | ||
self.yaml_cfg = Dict(yaml.safe_load(f)) | ||
self.postprocessor = build_postprocess(self.yaml_cfg.postprocess) | ||
|
||
def __call__(self, pred): | ||
return self.postprocessor(pred) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import os | ||
import time | ||
import sys | ||
import numpy as np | ||
import yaml | ||
from addict import Dict | ||
|
||
__dir__ = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) | ||
|
||
from pipeline.framework.module_base import ModuleBase | ||
from pipeline.tasks import TaskType | ||
from infer.classification.classification import INFER_CLS_MAP | ||
|
||
|
||
class ClsInferNode(ModuleBase): | ||
def __init__(self, args, msg_queue, tqdm_info): | ||
super(ClsInferNode, self).__init__(args, msg_queue, tqdm_info) | ||
self.args = args | ||
self.cls_model = None | ||
self.task_type = self.args.task_type | ||
|
||
def init_self_args(self): | ||
with open(self.args.cls_model_name_or_config, "r") as f: | ||
self.yaml_cfg = Dict(yaml.safe_load(f)) | ||
self.batch_size = self.yaml_cfg.predict.loader.batch_size | ||
ClsModel = INFER_CLS_MAP[self.yaml_cfg.predict.backend] | ||
self.cls_model = ClsModel(self.args) | ||
super().init_self_args() | ||
|
||
def process(self, input_data): | ||
""" | ||
Input: | ||
- input_data.data: [np.ndarray], shape:[3,w,h], e.g. [3,48,192] | ||
Output: | ||
- input_data.data: [np.ndarray], shape:[?,2] | ||
""" | ||
if input_data.skip: | ||
self.send_to_next_module(input_data) | ||
return | ||
|
||
data = input_data.data | ||
data = [np.expand_dims(d, 0) for d in data if len(d.shape) == 3] | ||
data = np.concatenate(data, axis=0) | ||
|
||
preds = [] | ||
for batch_i in range(data.shape[0] // self.batch_size + 1): | ||
start_i = batch_i * self.batch_size | ||
end_i = (batch_i + 1) * self.batch_size | ||
d = data[start_i:end_i] | ||
if d.shape[0] == 0: | ||
continue | ||
pred = self.cls_model([d]) | ||
preds.append(pred[0]) | ||
preds = np.concatenate(preds, axis=0) | ||
input_data.data = {"pred": preds} | ||
self.send_to_next_module(input_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import os | ||
import time | ||
import sys | ||
import numpy as np | ||
import yaml | ||
from addict import Dict | ||
|
||
__dir__ = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
from pipeline.framework.module_base import ModuleBase | ||
from pipeline.tasks import TaskType | ||
from infer.classification.classification import ClsPostProcess | ||
from tools.infer.text.utils import crop_text_region | ||
from pipeline.data_process.utils.cv_utils import crop_box_from_image | ||
|
||
|
||
class ClsPostNode(ModuleBase): | ||
def __init__(self, args, msg_queue, tqdm_info): | ||
super(ClsPostNode, self).__init__(args, msg_queue, tqdm_info) | ||
self.cls_postprocess = ClsPostProcess(args) | ||
self.task_type = self.args.task_type | ||
self.cls_thresh = 0.9 | ||
|
||
def init_self_args(self): | ||
super().init_self_args() | ||
|
||
def process(self, input_data): | ||
""" | ||
Input: | ||
- input_data.data: [np.ndarray], shape:[?,2] | ||
Output: | ||
- input_data.sub_image_list: [np.ndarray], shape:[1,3,-1,-1], e.g. [1,3,48,192] | ||
- input_data.data = None | ||
or | ||
- input_data.infer_result = [(str, float)] | ||
""" | ||
if input_data.skip: | ||
self.send_to_next_module(input_data) | ||
return | ||
|
||
data = input_data.data | ||
pred = data["pred"] | ||
output = self.cls_postprocess(pred) | ||
angles = output["angles"] | ||
scores = np.array(output["scores"]).tolist() | ||
|
||
batch = input_data.sub_image_size | ||
if self.task_type.value == TaskType.DET_CLS_REC.value: | ||
sub_images = input_data.sub_image_list | ||
for i in range(batch): | ||
angle, score = angles[i], scores[i] | ||
if "180" == angle and score > self.cls_thresh: | ||
sub_images[i] = cv2.rotate(sub_images[i], cv2.ROTATE_180) | ||
input_data.sub_image_list = sub_images | ||
else: | ||
input_data.infer_result = [(angle, score) for angle, score in zip(angles, scores)] | ||
|
||
input_data.data = None | ||
self.send_to_next_module(input_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import argparse | ||
import os | ||
import time | ||
import sys | ||
import numpy as np | ||
|
||
__dir__ = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) | ||
|
||
from pipeline.framework.module_base import ModuleBase | ||
from pipeline.tasks import TaskType | ||
from infer.classification.classification import ClsPreProcess | ||
|
||
|
||
class ClsPreNode(ModuleBase): | ||
def __init__(self, args, msg_queue, tqdm_info): | ||
super(ClsPreNode, self).__init__(args, msg_queue, tqdm_info) | ||
self.cls_preprocesser = ClsPreProcess(args) | ||
self.task_type = self.args.task_type | ||
|
||
def init_self_args(self): | ||
super().init_self_args() | ||
return {"batch_size": 1} | ||
|
||
def process(self, input_data): | ||
if input_data.skip: | ||
self.send_to_next_module(input_data) | ||
return | ||
|
||
if self.task_type.value == TaskType.REC.value: | ||
image = input_data.frame[0] | ||
data = [self.cls_preprocesser(image)["image"]] | ||
input_data.sub_image_size = 1 | ||
input_data.data = data | ||
self.send_to_next_module(input_data) | ||
else: | ||
sub_image_list = input_data.sub_image_list | ||
data = [self.cls_preprocesser(image)["image"] for split_image in sub_image_list] | ||
input_data.sub_image_size = len(sub_image_list) | ||
input_data.data = data | ||
self.send_to_next_module(input_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .collect_node import CollectNode | ||
from .decode_node import DecodeNode | ||
from .handout_node import HandoutNode |
Oops, something went wrong.