Skip to content

Commit 0c5be08

Browse files
QYuFong and QiuYuFong authored
clean code (#826)
Co-authored-by: qiuyufeng <[email protected]>
1 parent 6dd7792 commit 0c5be08

File tree

6 files changed

+116
-32
lines changed

6 files changed

+116
-32
lines changed

mindocr/data/transforms/det_east_transforms.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ast
2+
import json
13
import math
24

35
import cv2
@@ -414,8 +416,17 @@ def _extract_vertices(self, data_labels):
414416
"""
415417
vertices_list = []
416418
labels_list = []
417-
data_labels = eval(data_labels)
418-
for data_label in data_labels:
419+
try:
420+
parsed_data = json.loads(data_labels)
421+
except json.JSONDecodeError:
422+
try:
423+
parsed_data = ast.literal_eval(data_labels)
424+
except (ValueError, SyntaxError) as e:
425+
raise ValueError(f"Invalid data format: {str(e)}") from e
426+
427+
if not isinstance(parsed_data, list):
428+
raise ValueError("Data labels should be a list")
429+
for data_label in parsed_data:
419430
vertices = data_label["points"]
420431
vertices = [item for point in vertices for item in point]
421432
vertices_list.append(vertices)

mindocr/data/transforms/svtr_transform.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -546,10 +546,19 @@ def __init__(self, max_text_length, character_dict_path=None, use_space_char=Fal
546546

547547
self.ctc_encode = CTCLabelEncodeForSVTR(max_text_length, character_dict_path, use_space_char, **kwargs)
548548
self.gtc_encode_type = gtc_encode
549+
# Explicitly register every supported gtc_encode class in this dict and look the class up by name from it.
550+
supported_gtc_encode = {}
549551
if gtc_encode is None:
550552
self.gtc_encode = SARLabelEncodeForSVTR(max_text_length, character_dict_path, use_space_char, **kwargs)
551553
else:
552-
self.gtc_encode = eval(gtc_encode)(max_text_length, character_dict_path, use_space_char, **kwargs)
554+
# MindOCR currently has no module that passes a custom `gtc_encode`, so this branch is not reached at
555+
# present. If custom encoders are supported later, register them in `supported_gtc_encode` and look the class up
556+
# from that dict; do not use the `eval` function.
557+
if gtc_encode not in supported_gtc_encode:
558+
raise ValueError(f"Get unsupported gtc_encode {gtc_encode}")
559+
self.gtc_encode = supported_gtc_encode[gtc_encode](
560+
max_text_length, character_dict_path, use_space_char, **kwargs
561+
)
553562

554563
def __call__(self, data):
555564
data_ctc = copy.deepcopy(data)
@@ -925,7 +934,7 @@ def __init__(
925934
jitter_prob=0.4,
926935
blur_prob=0.4,
927936
hsv_aug_prob=0.4,
928-
**kwargs
937+
**kwargs,
929938
):
930939
self.crop_prob = crop_prob
931940
self.reverse_prob = reverse_prob
@@ -973,7 +982,7 @@ def __init__(
973982
jitter_prob=0.4,
974983
blur_prob=0.4,
975984
hsv_aug_prob=0.4,
976-
**kwargs
985+
**kwargs,
977986
):
978987
self.tia_prob = tia_prob
979988
self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob, jitter_prob, blur_prob, hsv_aug_prob)
@@ -1078,7 +1087,7 @@ def __init__(
10781087
character_dict_path=".mindocr/utils/dict/ch_dict.txt",
10791088
padding=True,
10801089
width_downsample_ratio=0.125,
1081-
**kwargs
1090+
**kwargs,
10821091
):
10831092
self.image_shape = image_shape
10841093
self.infer_mode = infer_mode

mindocr/data/transforms/transforms_factory.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
"""
22
Create and run transformations from a config or predefined transformation pipeline
33
"""
4-
import logging
54
from typing import Dict, List
65

7-
import numpy as np
8-
96
from .det_east_transforms import *
107
from .det_transforms import *
118
from .general_transforms import *
@@ -15,6 +12,63 @@
1512
from .svtr_transform import *
1613
from .table_transform import *
1714

15+
SUPPORTED_TRANSFORMS = {
16+
"EASTProcessTrain": EASTProcessTrain,
17+
"DetLabelEncode": DetLabelEncode,
18+
"BorderMap": BorderMap,
19+
"ShrinkBinaryMap": ShrinkBinaryMap,
20+
"expand_poly": expand_poly,
21+
"PSEGtDecode": PSEGtDecode,
22+
"ValidatePolygons": ValidatePolygons,
23+
"RandomCropWithBBox": RandomCropWithBBox,
24+
"RandomCropWithMask": RandomCropWithMask,
25+
"DetResize": DetResize,
26+
"DecodeImage": DecodeImage,
27+
"NormalizeImage": NormalizeImage,
28+
"ToCHWImage": ToCHWImage,
29+
"PackLoaderInputs": PackLoaderInputs,
30+
"RandomScale": RandomScale,
31+
"RandomColorAdjust": RandomColorAdjust,
32+
"RandomRotate": RandomRotate,
33+
"RandomHorizontalFlip": RandomHorizontalFlip,
34+
"LayoutResize": LayoutResize,
35+
"ImageStridePad": ImageStridePad,
36+
"VQATokenLabelEncode": VQATokenLabelEncode,
37+
"VQATokenPad": VQATokenPad,
38+
"VQASerTokenChunk": VQASerTokenChunk,
39+
"VQAReTokenRelation": VQAReTokenRelation,
40+
"VQAReTokenChunk": VQAReTokenChunk,
41+
"TensorizeEntitiesRelations": TensorizeEntitiesRelations,
42+
"ABINetTransforms": ABINetTransforms,
43+
"ABINetRecAug": ABINetRecAug,
44+
"ABINetEval": ABINetEval,
45+
"ABINetEvalTransforms": ABINetEvalTransforms,
46+
"RecCTCLabelEncode": RecCTCLabelEncode,
47+
"RecAttnLabelEncode": RecAttnLabelEncode,
48+
"RecMasterLabelEncode": RecMasterLabelEncode,
49+
"VisionLANLabelEncode": VisionLANLabelEncode,
50+
"RecResizeImg": RecResizeImg,
51+
"RecResizeNormForInfer": RecResizeNormForInfer,
52+
"SVTRRecResizeImg": SVTRRecResizeImg,
53+
"Rotate90IfVertical": Rotate90IfVertical,
54+
"ClsLabelEncode": ClsLabelEncode,
55+
"SARLabelEncode": SARLabelEncode,
56+
"RobustScannerRecResizeImg": RobustScannerRecResizeImg,
57+
"SVTRRecAug": SVTRRecAug,
58+
"MultiLabelEncode": MultiLabelEncode,
59+
"RecConAug": RecConAug,
60+
"RecAug": RecAug,
61+
"RecResizeImgForSVTR": RecResizeImgForSVTR,
62+
"BaseRecLabelEncode": BaseRecLabelEncode,
63+
"AttnLabelEncode": AttnLabelEncode,
64+
"TableLabelEncode": TableLabelEncode,
65+
"TableMasterLabelEncode": TableMasterLabelEncode,
66+
"ResizeTableImage": ResizeTableImage,
67+
"PaddingTableImage": PaddingTableImage,
68+
"TableBoxEncode": TableBoxEncode,
69+
"TableImageNorm": TableImageNorm,
70+
}
71+
1872
__all__ = ["create_transforms", "run_transforms", "transforms_dbnet_icdar15"]
1973
_logger = logging.getLogger(__name__)
2074

@@ -45,9 +99,9 @@ def create_transforms(transform_pipeline: List, global_config: Dict = None):
4599
param = {} if transform_config[trans_name] is None else transform_config[trans_name]
46100
if global_config is not None:
47101
param.update(global_config)
48-
# TODO: assert undefined transform class
49-
50-
transform = eval(trans_name)(**param)
102+
# For security reasons, we no longer use the eval function to dynamically obtain class objects.
103+
# If you need to add a new transform class, please explicitly add it to the ``SUPPORTED_TRANSFORMS`` dict.
104+
transform = SUPPORTED_TRANSFORMS[trans_name](**param)
51105
transforms.append(transform)
52106
elif callable(transform_config):
53107
transforms.append(transform_config)

mindocr/models/backbones/mindcv_models/download.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,11 @@ def extract_archive(self, from_path: str, to_path: str = None) -> str:
9999
def download_file(self, url: str, file_path: str, chunk_size: int = 1024):
100100
"""Download a file."""
101101

102-
# no check certificate
102+
# For security reasons, this repository does not disable SSL certificate verification.
103+
# If you really need to, disable hostname checking and certificate verification in your own code.
103104
ctx = ssl.create_default_context()
104-
ctx.check_hostname = False
105-
ctx.verify_mode = ssl.CERT_NONE
105+
# ctx.check_hostname = False
106+
# ctx.verify_mode = ssl.CERT_NONE
106107

107108
# Define request headers.
108109
headers = {"User-Agent": self.USER_AGENT}

mindocr/postprocess/builder.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,24 @@
2323

2424
__all__ = ["build_postprocess"]
2525

26-
supported_postprocess = (
27-
det_db_postprocess.__all__
28-
+ det_pse_postprocess.__all__
29-
+ det_east_postprocess.__all__
30-
+ rec_postprocess.__all__
31-
+ cls_postprocess.__all__
32-
+ rec_abinet_postprocess.__all__
33-
+ kie_ser_postprocess.__all__
34-
+ kie_re_postprocess.__all__
35-
+ layout_postprocess.__all__
36-
+ table_postprocess.__all__
37-
)
26+
SUPPORTED_POSTPROCESS = {
27+
"DBPostprocess": DBPostprocess,
28+
"PSEPostprocess": PSEPostprocess,
29+
"EASTPostprocess": EASTPostprocess,
30+
"CTCLabelDecode": CTCLabelDecode,
31+
"RecCTCLabelDecode": RecCTCLabelDecode,
32+
"RecAttnLabelDecode": RecAttnLabelDecode,
33+
"RecMasterLabelDecode": RecMasterLabelDecode,
34+
"VisionLANPostProcess": VisionLANPostProcess,
35+
"SARLabelDecode": SARLabelDecode,
36+
"ClsPostprocess": ClsPostprocess,
37+
"ABINetLabelDecode": ABINetLabelDecode,
38+
"VQASerTokenLayoutLMPostProcess": VQASerTokenLayoutLMPostProcess,
39+
"VQAReTokenLayoutLMPostProcess": VQAReTokenLayoutLMPostProcess,
40+
"YOLOv8Postprocess": YOLOv8Postprocess,
41+
"Layoutlmv3Postprocess": Layoutlmv3Postprocess,
42+
"TableMasterLabelDecode": TableMasterLabelDecode,
43+
}
3844

3945

4046
def build_postprocess(config: dict):
@@ -57,11 +63,11 @@ def build_postprocess(config: dict):
5763
>>> postprocess
5864
"""
5965
proc = config.pop("name")
60-
if proc in supported_postprocess:
61-
postprocessor = eval(proc)(**config)
62-
elif proc is None:
66+
if proc is None:
6367
return None
68+
if proc in SUPPORTED_POSTPROCESS:
69+
postprocessor = SUPPORTED_POSTPROCESS[proc](**config)
6470
else:
65-
raise ValueError(f"Invalid postprocess name {proc}, support postprocess are {supported_postprocess}")
71+
raise ValueError(f"Invalid postprocess name {proc}, support postprocess are {SUPPORTED_POSTPROCESS.keys()}")
6672

6773
return postprocessor

tools/arg_parser.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def _parse_options(opts: list):
6262
"=" in opt_str
6363
), "Invalid option {}. A valid option must be in the format of {{key_name}}={{value}}".format(opt_str)
6464
k, v = opt_str.strip().split("=")
65-
options[k] = yaml.load(v, Loader=yaml.Loader)
65+
try:
66+
options[k] = yaml.load(v, Loader=yaml.SafeLoader)
67+
except yaml.YAMLError as e:
68+
raise ValueError(f"Failed to parse value for key '{k}': {str(e)}") from e
6669
# print('Parsed options: ', options)
6770

6871
return options

0 commit comments

Comments
 (0)