|
1 | 1 | """ |
2 | 2 | Create and run transformations from a config or predefined transformation pipeline |
3 | 3 | """ |
4 | | -import logging |
5 | 4 | from typing import Dict, List |
6 | 5 |
|
7 | | -import numpy as np |
8 | | - |
9 | 6 | from .det_east_transforms import * |
10 | 7 | from .det_transforms import * |
11 | 8 | from .general_transforms import * |
|
15 | 12 | from .svtr_transform import * |
16 | 13 | from .table_transform import * |
17 | 14 |
|
| 15 | +SUPPORTED_TRANSFORMS = { |
| 16 | + "EASTProcessTrain": EASTProcessTrain, |
| 17 | + "DetLabelEncode": DetLabelEncode, |
| 18 | + "BorderMap": BorderMap, |
| 19 | + "ShrinkBinaryMap": ShrinkBinaryMap, |
| 20 | + "expand_poly": expand_poly, |
| 21 | + "PSEGtDecode": PSEGtDecode, |
| 22 | + "ValidatePolygons": ValidatePolygons, |
| 23 | + "RandomCropWithBBox": RandomCropWithBBox, |
| 24 | + "RandomCropWithMask": RandomCropWithMask, |
| 25 | + "DetResize": DetResize, |
| 26 | + "DecodeImage": DecodeImage, |
| 27 | + "NormalizeImage": NormalizeImage, |
| 28 | + "ToCHWImage": ToCHWImage, |
| 29 | + "PackLoaderInputs": PackLoaderInputs, |
| 30 | + "RandomScale": RandomScale, |
| 31 | + "RandomColorAdjust": RandomColorAdjust, |
| 32 | + "RandomRotate": RandomRotate, |
| 33 | + "RandomHorizontalFlip": RandomHorizontalFlip, |
| 34 | + "LayoutResize": LayoutResize, |
| 35 | + "ImageStridePad": ImageStridePad, |
| 36 | + "VQATokenLabelEncode": VQATokenLabelEncode, |
| 37 | + "VQATokenPad": VQATokenPad, |
| 38 | + "VQASerTokenChunk": VQASerTokenChunk, |
| 39 | + "VQAReTokenRelation": VQAReTokenRelation, |
| 40 | + "VQAReTokenChunk": VQAReTokenChunk, |
| 41 | + "TensorizeEntitiesRelations": TensorizeEntitiesRelations, |
| 42 | + "ABINetTransforms": ABINetTransforms, |
| 43 | + "ABINetRecAug": ABINetRecAug, |
| 44 | + "ABINetEval": ABINetEval, |
| 45 | + "ABINetEvalTransforms": ABINetEvalTransforms, |
| 46 | + "RecCTCLabelEncode": RecCTCLabelEncode, |
| 47 | + "RecAttnLabelEncode": RecAttnLabelEncode, |
| 48 | + "RecMasterLabelEncode": RecMasterLabelEncode, |
| 49 | + "VisionLANLabelEncode": VisionLANLabelEncode, |
| 50 | + "RecResizeImg": RecResizeImg, |
| 51 | + "RecResizeNormForInfer": RecResizeNormForInfer, |
| 52 | + "SVTRRecResizeImg": SVTRRecResizeImg, |
| 53 | + "Rotate90IfVertical": Rotate90IfVertical, |
| 54 | + "ClsLabelEncode": ClsLabelEncode, |
| 55 | + "SARLabelEncode": SARLabelEncode, |
| 56 | + "RobustScannerRecResizeImg": RobustScannerRecResizeImg, |
| 57 | + "SVTRRecAug": SVTRRecAug, |
| 58 | + "MultiLabelEncode": MultiLabelEncode, |
| 59 | + "RecConAug": RecConAug, |
| 60 | + "RecAug": RecAug, |
| 61 | + "RecResizeImgForSVTR": RecResizeImgForSVTR, |
| 62 | + "BaseRecLabelEncode": BaseRecLabelEncode, |
| 63 | + "AttnLabelEncode": AttnLabelEncode, |
| 64 | + "TableLabelEncode": TableLabelEncode, |
| 65 | + "TableMasterLabelEncode": TableMasterLabelEncode, |
| 66 | + "ResizeTableImage": ResizeTableImage, |
| 67 | + "PaddingTableImage": PaddingTableImage, |
| 68 | + "TableBoxEncode": TableBoxEncode, |
| 69 | + "TableImageNorm": TableImageNorm, |
| 70 | +} |
| 71 | + |
18 | 72 | __all__ = ["create_transforms", "run_transforms", "transforms_dbnet_icdar15"] |
19 | 73 | _logger = logging.getLogger(__name__) |
20 | 74 |
|
@@ -45,9 +99,9 @@ def create_transforms(transform_pipeline: List, global_config: Dict = None): |
45 | 99 | param = {} if transform_config[trans_name] is None else transform_config[trans_name] |
46 | 100 | if global_config is not None: |
47 | 101 | param.update(global_config) |
48 | | - # TODO: assert undefined transform class |
49 | | - |
50 | | - transform = eval(trans_name)(**param) |
| 102 | + # For security reasons, we no longer use the eval function to dynamically obtain class objects. |
| 103 | + # If you need to add a new transform class, please explicitly add it to the ``SUPPORTED_TRANSFORMS`` dict. |
| 104 | + transform = SUPPORTED_TRANSFORMS[trans_name](**param) |
51 | 105 | transforms.append(transform) |
52 | 106 | elif callable(transform_config): |
53 | 107 | transforms.append(transform_config) |
|
0 commit comments