diff --git a/configs/kie/layoutlmv3/README.md b/configs/kie/layoutlmv3/README.md index cc18fd725..0c8b8a14e 100644 --- a/configs/kie/layoutlmv3/README.md +++ b/configs/kie/layoutlmv3/README.md @@ -144,7 +144,7 @@ Apart from the dataset setting, please also check the following important args: system: mode: distribute: False # `True` for distributed training, `False` for standalone training - amp_level: 'O0' + amp_level: 'O3' seed: 42 val_while_train: True # Validate while training drop_overflow_update: False @@ -157,7 +157,6 @@ model: name: TokenClassificationHead num_classes: 7 use_visual_backbone: True - use_float16: True pretrained: ... train: diff --git a/configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml b/configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml index b42c84116..5c5944ed7 100644 --- a/configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml +++ b/configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml @@ -1,7 +1,7 @@ system: mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore distribute: False - amp_level: "O0" + amp_level: "O3" seed: 42 log_interval: 10 val_start_epoch: 50 @@ -17,7 +17,6 @@ model: name: TokenClassificationHead num_classes: 7 use_visual_backbone: True - use_float16: True pretrained: postprocess: diff --git a/configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml b/configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml index 0e7935c9d..a59c4ff44 100644 --- a/configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml +++ b/configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml @@ -1,7 +1,7 @@ system: mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore distribute: False - amp_level: 'O0' + amp_level: 'O3' seed: 42 log_interval: 10 val_while_train: True @@ -16,11 +16,9 @@ model: pretrained: True num_classes: &num_classes 7 use_visual_backbone: True - use_float16: True head: name: RelationExtractionHead use_visual_backbone: True - use_float16: True pretrained: postprocess: diff --git a/configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml 
b/configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml index e371bcf49..64fd4b328 100644 --- a/configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml +++ b/configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml @@ -1,7 +1,7 @@ system: mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore distribute: False - amp_level: 'O0' + amp_level: 'O3' seed: 42 log_interval: 10 val_while_train: True @@ -15,12 +15,10 @@ model: pretrained: True num_classes: &num_classes 7 use_visual_backbone: True - use_float16: True head : name: TokenClassificationHead num_classes: 7 use_visual_backbone: True - use_float16: True pretrained: postprocess: diff --git a/configs/kie/vi_layoutxlm/README.md b/configs/kie/vi_layoutxlm/README.md index 5326f11f5..f26f5995c 100644 --- a/configs/kie/vi_layoutxlm/README.md +++ b/configs/kie/vi_layoutxlm/README.md @@ -159,7 +159,7 @@ Apart from the dataset setting, please also check the following important args: system: mode: distribute: False # `True` for distributed training, `False` for standalone training - amp_level: 'O0' + amp_level: 'O3' seed: 42 val_while_train: True # Validate while training drop_overflow_update: False @@ -171,12 +171,10 @@ model: pretrained: True num_classes: &num_classes 7 use_visual_backbone: False - use_float16: True head : name: TokenClassificationHead num_classes: 7 use_visual_backbone: False - use_float16: True pretrained: ... 
train: diff --git a/configs/kie/vi_layoutxlm/README_CN.md b/configs/kie/vi_layoutxlm/README_CN.md index f6f843807..4da2a631f 100644 --- a/configs/kie/vi_layoutxlm/README_CN.md +++ b/configs/kie/vi_layoutxlm/README_CN.md @@ -156,7 +156,7 @@ eval: system: mode: distribute: False # 分布式训练为True,单卡训练为False - amp_level: 'O0' + amp_level: 'O3' seed: 42 val_while_train: True # 边训练边验证 drop_overflow_update: False @@ -168,12 +168,10 @@ model: pretrained: True num_classes: &num_classes 7 use_visual_backbone: False - use_float16: True head : name: TokenClassificationHead num_classes: 7 use_visual_backbone: False - use_float16: True pretrained: ... train: diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml index 2fae9ec62..cd8fc57fe 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml @@ -1,7 +1,7 @@ system: mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore distribute: False - amp_level: "O0" + amp_level: "O3" seed: 42 log_interval: 10 val_while_train: True @@ -16,11 +16,9 @@ model: pretrained: True num_classes: &num_classes 7 use_visual_backbone: False - use_float16: True head: name: RelationExtractionHead use_visual_backbone: False - use_float16: True pretrained: postprocess: @@ -90,11 +88,11 @@ train: "bbox", "attention_mask", "token_type_ids", - "image", "question", "question_label", "answer", "answer_label", + "image", "relation_label", ] net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns diff --git a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml index 3682fc45c..647724df9 100644 --- a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml +++ b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml @@ -1,7 +1,7 @@ system: mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore 
distribute: False - amp_level: 'O0' + amp_level: 'O3' seed: 42 log_interval: 10 val_while_train: True @@ -15,12 +15,10 @@ model: pretrained: True num_classes: &num_classes 7 use_visual_backbone: False - use_float16: True head : name: TokenClassificationHead num_classes: 7 use_visual_backbone: False - use_float16: True pretrained: postprocess: diff --git a/mindocr/losses/kie_loss.py b/mindocr/losses/kie_loss.py index 60fd7630f..8b864282b 100644 --- a/mindocr/losses/kie_loss.py +++ b/mindocr/losses/kie_loss.py @@ -28,6 +28,6 @@ def __init__(self, **kwargs): super().__init__() self.loss_fct = nn.CrossEntropyLoss() - def construct(self, predicts, attention_mask, labels): + def construct(self, predicts, labels): loss = self.loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32)) return loss diff --git a/mindocr/models/backbones/layoutlmv3/configuration.py b/mindocr/models/backbones/layoutlmv3/configuration.py index 93243ddb5..259719e79 100644 --- a/mindocr/models/backbones/layoutlmv3/configuration.py +++ b/mindocr/models/backbones/layoutlmv3/configuration.py @@ -3,9 +3,8 @@ @dataclass class LayoutLMv3PretrainedConfig: - def __init__(self, use_float16=False): + def __init__(self): pretrained_config = { - "use_float16": use_float16, "fast_qkv": False, "vocab_size": 250002, "hidden_size": 768, diff --git a/mindocr/models/backbones/layoutlmv3/layoutlmv3.py b/mindocr/models/backbones/layoutlmv3/layoutlmv3.py index 1e1bc1f9b..7cf67a49d 100644 --- a/mindocr/models/backbones/layoutlmv3/layoutlmv3.py +++ b/mindocr/models/backbones/layoutlmv3/layoutlmv3.py @@ -175,7 +175,7 @@ def construct( if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) - attention_scores = attention_scores + attention_mask.astype(self.dense_dtype) + attention_scores = attention_scores + attention_mask.astype(attention_scores.dtype) # Normalize the attention scores to probabilities. 
# Use the trick of the CogView paper to stablize training @@ -227,11 +227,8 @@ def __init__(self, config): self.has_relative_attention_bias = config.has_relative_attention_bias self.has_spatial_attention_bias = config.has_spatial_attention_bias self.patch_size = config.patch_size - self.use_float16 = config.use_float16 - self.dense_dtype = mstype.float32 - if self.use_float16 is True: - self.dense_dtype = mstype.float16 - self.min = finfo(self.dense_dtype) + self.float32_min = finfo(mstype.float32) + self.float16_min = finfo(mstype.float16) self.out_channels = 1 self.use_visual_backbone = True @@ -342,7 +339,15 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape, dtype # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. # fp16 compatibility extended_attention_mask = extended_attention_mask.astype(dtype) - extended_attention_mask = (1.0 - extended_attention_mask) * self.min + + if dtype == mstype.float32: + minimum = self.float32_min + elif dtype == mstype.float16: + minimum = self.float16_min + else: + raise TypeError(f"For 'get_extended_attention_mask', the input dtype should be float32 or float16, but got {dtype}") + + extended_attention_mask = (1.0 - extended_attention_mask) * minimum return extended_attention_mask def get_head_mask(self, head_mask, num_hidden_layers: int, is_attention_chunked: bool = False): @@ -518,7 +523,7 @@ def construct( @register_backbone -def layoutlmv3(use_float16: bool = True, **kwargs): - pretrained_config = LayoutLMv3PretrainedConfig(use_float16) +def layoutlmv3(**kwargs): + pretrained_config = LayoutLMv3PretrainedConfig() model = LayoutLMv3Model(pretrained_config) return model diff --git a/mindocr/models/backbones/layoutxlm/configuration.py b/mindocr/models/backbones/layoutxlm/configuration.py index ece3b3b7b..241d524e5 100644 --- a/mindocr/models/backbones/layoutxlm/configuration.py +++ b/mindocr/models/backbones/layoutxlm/configuration.py @@ -3,13 +3,11 @@ @dataclass class LayoutXLMPretrainedConfig: - def __init__(self, use_visual_backbone=True,
use_float16=False): + def __init__(self, use_visual_backbone=True): pretrained_config = { "use_visual_backbone": use_visual_backbone, - "use_float16": use_float16, "attention_probs_dropout_prob": 0.1, "use_visual_backbone": use_visual_backbone, - "use_float16": use_float16, "bos_token_id": 0, "coordinate_size": 128, "eos_token_id": 2, diff --git a/mindocr/models/backbones/layoutxlm/layoutxlm.py b/mindocr/models/backbones/layoutxlm/layoutxlm.py index 4737c03a4..8fe78893c 100644 --- a/mindocr/models/backbones/layoutxlm/layoutxlm.py +++ b/mindocr/models/backbones/layoutxlm/layoutxlm.py @@ -99,18 +99,12 @@ def __init__(self, config): self.has_visual_segment_embedding = config.has_visual_segment_embedding self.embeddings = LayoutXLMEmbeddings(config) self.use_visual_backbone = config.use_visual_backbone - self.use_float16 = config.use_float16 - self.dense_dtype = mstype.float32 - if self.use_float16 is True: - self.dense_dtype = mstype.float16 if self.use_visual_backbone is True: set_context(jit_syntax_level=0) self.visual = VisualBackbone(config) self.visual.freeze() - self.visual_proj = nn.Dense(config.image_feature_pool_shape[-1], config.hidden_size).to_float( - self.dense_dtype - ) + self.visual_proj = nn.Dense(config.image_feature_pool_shape[-1], config.hidden_size) if self.has_visual_segment_embedding: self.visual_segment_embedding = Parameter(nn.Embedding(1, config.hidden_size).embedding_table[0]) self.visual_LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps) @@ -302,8 +296,8 @@ def construct( @register_backbone -def layoutxlm(pretrained: bool = True, use_visual_backbone: bool = True, use_float16: bool = False, **kwargs): - pretrained_config = LayoutXLMPretrainedConfig(use_visual_backbone, use_float16) +def layoutxlm(pretrained: bool = True, use_visual_backbone: bool = True, **kwargs): + pretrained_config = LayoutXLMPretrainedConfig(use_visual_backbone) model = LayoutXLMModel(pretrained_config) if pretrained: if use_visual_backbone is 
True: diff --git a/mindocr/models/backbones/layoutxlm/visual_backbone.py b/mindocr/models/backbones/layoutxlm/visual_backbone.py index 763824227..c9fc0cec5 100644 --- a/mindocr/models/backbones/layoutxlm/visual_backbone.py +++ b/mindocr/models/backbones/layoutxlm/visual_backbone.py @@ -130,7 +130,7 @@ def construct(self, x): top_block_in_feature = bottom_up_features[self.top_block.in_feature] else: top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)] - results.extend(self.top_block(top_block_in_feature.astype(ms.float16))) + results.extend(self.top_block(top_block_in_feature)) assert len(self._out_features) == len(results) diff --git a/mindocr/models/backbones/transformer_common/layer.py b/mindocr/models/backbones/transformer_common/layer.py index d906a26a8..8757cd123 100644 --- a/mindocr/models/backbones/transformer_common/layer.py +++ b/mindocr/models/backbones/transformer_common/layer.py @@ -11,9 +11,9 @@ def finfo(dtype): if dtype == mstype.float32: - return np.finfo(np.float32).min + return Tensor(np.finfo(np.float32).min) elif dtype == mstype.float16: - return np.finfo(np.float16).min + return Tensor(np.finfo(np.float16).min) else: raise TypeError(f"For 'finfo', the input dtype should be float32 or float16, bug got {dtype}") @@ -94,24 +94,18 @@ def __init__(self, config): self.has_relative_attention_bias = config.has_relative_attention_bias self.has_spatial_attention_bias = config.has_spatial_attention_bias - self.use_float16 = config.use_float16 - self.dense_dtype = mstype.float32 - if self.use_float16 is True: - self.dense_dtype = mstype.float16 - if config.fast_qkv: - self.qkv_linear = nn.Dense(config.hidden_size, 3 * self.all_head_size, has_bias=False).to_float( - self.dense_dtype - ) - self.q_bias = Parameter(initializer(Constant(0.0), [1, 1, self.all_head_size], self.dense_dtype)) - self.v_bias = Parameter(initializer(Constant(0.0), [1, 1, self.all_head_size], self.dense_dtype)) + self.qkv_linear = 
nn.Dense(config.hidden_size, 3 * self.all_head_size, has_bias=False) + self.q_bias = Parameter(initializer(Constant(0.0), [1, 1, self.all_head_size])) + self.v_bias = Parameter(initializer(Constant(0.0), [1, 1, self.all_head_size])) else: - self.query = nn.Dense(config.hidden_size, self.all_head_size).to_float(self.dense_dtype) - self.key = nn.Dense(config.hidden_size, self.all_head_size).to_float(self.dense_dtype) - self.value = nn.Dense(config.hidden_size, self.all_head_size).to_float(self.dense_dtype) + self.query = nn.Dense(config.hidden_size, self.all_head_size) + self.key = nn.Dense(config.hidden_size, self.all_head_size) + self.value = nn.Dense(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(p=config.attention_probs_dropout_prob) - self.min = finfo(self.dense_dtype) + self.float32_min = finfo(mstype.float32) + self.float16_min = finfo(mstype.float16) def transpose_for_scores(self, x): new_x_shape = list(x.shape[:-1]) + [ @@ -168,10 +162,20 @@ def construct( attention_scores += rel_pos if self.has_spatial_attention_bias: attention_scores += rel_2d_pos + + minimum = None + if attention_scores.dtype == mstype.float32: + minimum = self.float32_min + elif attention_scores.dtype == mstype.float16: + minimum = self.float16_min + else: + raise ValueError("Dtype of attention_scores must be ms.float32 or ms.float16," + f" but got {attention_scores.dtype}") + attention_scores = ops.masked_fill( attention_scores, ops.stop_gradient(attention_mask.astype(mstype.bool_)), - self.min, + minimum ) attention_probs = ops.softmax(attention_scores, axis=-1) # This is actually dropping out entire tokens to attend to, which might @@ -193,11 +197,7 @@ def construct( class LayoutXLMSelfOutput(nn.Cell): def __init__(self, config): super().__init__() - self.use_float16 = config.use_float16 - self.dense_dtype = mstype.float32 - if self.use_float16 is True: - self.dense_dtype = mstype.float16 - self.dense = nn.Dense(config.hidden_size,
config.hidden_size).to_float(self.dense_dtype) + self.dense = nn.Dense(config.hidden_size, config.hidden_size) self.LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps) self.dropout = nn.Dropout(p=config.hidden_dropout_prob) @@ -251,11 +251,7 @@ def construct( class LayoutXLMIntermediate(nn.Cell): def __init__(self, config): super().__init__() - self.use_float16 = config.use_float16 - self.dense_dtype = mstype.float32 - if self.use_float16 is True: - self.dense_dtype = mstype.float16 - self.dense = nn.Dense(config.hidden_size, config.intermediate_size).to_float(self.dense_dtype) + self.dense = nn.Dense(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = act_fn[config.hidden_act] else: @@ -270,11 +266,7 @@ def construct(self, hidden_states): class LayoutXLMOutput(nn.Cell): def __init__(self, config): super().__init__() - self.use_float16 = config.use_float16 - self.dense_dtype = mstype.float32 - if self.use_float16 is True: - self.dense_dtype = mstype.float16 - self.dense = nn.Dense(config.intermediate_size, config.hidden_size).to_float(self.dense_dtype) + self.dense = nn.Dense(config.intermediate_size, config.hidden_size) self.LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps) self.dropout = nn.Dropout(p=config.hidden_dropout_prob) @@ -347,18 +339,11 @@ def __init__(self, config): self.has_relative_attention_bias = config.has_relative_attention_bias self.has_spatial_attention_bias = config.has_spatial_attention_bias - self.use_float16 = config.use_float16 - self.dense_dtype = mstype.float32 - if self.use_float16 is True: - self.dense_dtype = mstype.float16 - if self.has_relative_attention_bias: self.rel_pos_bins = config.rel_pos_bins self.max_rel_pos = config.max_rel_pos self.rel_pos_onehot_size = config.rel_pos_bins - self.rel_pos_bias = nn.Dense(self.rel_pos_onehot_size, config.num_attention_heads, has_bias=False).to_float( - mstype.float16 - ) + 
self.rel_pos_bias = nn.Dense(self.rel_pos_onehot_size, config.num_attention_heads, has_bias=False) if self.has_spatial_attention_bias: self.max_rel_2d_pos = config.max_rel_2d_pos @@ -366,10 +351,10 @@ def __init__(self, config): self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins self.rel_pos_x_bias = nn.Dense( self.rel_2d_pos_onehot_size, config.num_attention_heads, has_bias=False - ).to_float(self.dense_dtype) + ) self.rel_pos_y_bias = nn.Dense( self.rel_2d_pos_onehot_size, config.num_attention_heads, has_bias=False - ).to_float(self.dense_dtype) + ) def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128): def test(relative_position, bidirectional=True, num_buckets=32, max_distance=128): @@ -506,11 +491,7 @@ def construct( class LayoutXLMPooler(nn.Cell): def __init__(self, config): super().__init__() - self.use_float16 = config.use_float16 - self.dense_dtype = mstype.float32 - if self.use_float16 is True: - self.dense_dtype = mstype.float16 - self.dense = nn.Dense(config.hidden_size, config.hidden_size).to_float(self.dense_dtype) + self.dense = nn.Dense(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def construct(self, hidden_states): diff --git a/mindocr/models/heads/kie_relationextraction_head.py b/mindocr/models/heads/kie_relationextraction_head.py index 581d005d6..23a7634a2 100644 --- a/mindocr/models/heads/kie_relationextraction_head.py +++ b/mindocr/models/heads/kie_relationextraction_head.py @@ -7,14 +7,13 @@ class BiaffineAttention(nn.Cell): """Implements a biaffine attention operator for binary relation classification.""" - def __init__(self, in_features, out_features, use_float16: bool = True): + def __init__(self, in_features, out_features): super(BiaffineAttention, self).__init__() - self.dense_dtype = float16 if use_float16 else float32 self.in_features = in_features self.out_features = out_features self.bilinear = nn.BiDense(in_features, in_features, out_features, 
has_bias=False) - self.linear = nn.Dense(2 * in_features, out_features).to_float(self.dense_dtype) + self.linear = nn.Dense(2 * in_features, out_features) def construct(self, x_1, x_2): return self.bilinear(x_1, x_2) + self.linear(ops.concat((x_1, x_2), axis=-1)) @@ -24,23 +23,22 @@ class REDecoder(nn.Cell): """ Decoder of relation extraction """ - def __init__(self, hidden_size=768, hidden_dropout_prob=0.1, use_float16: bool = True): + def __init__(self, hidden_size=768, hidden_dropout_prob=0.1): super(REDecoder, self).__init__() - self.dense_dtype = float16 if use_float16 else float32 self.entity_emb = nn.Embedding(3, hidden_size) self.ffnn_head = nn.SequentialCell( - nn.Dense(hidden_size * 2, hidden_size).to_float(self.dense_dtype), + nn.Dense(hidden_size * 2, hidden_size), nn.ReLU(), nn.Dropout(p=hidden_dropout_prob), - nn.Dense(hidden_size, hidden_size // 2).to_float(self.dense_dtype), + nn.Dense(hidden_size, hidden_size // 2), nn.ReLU(), nn.Dropout(p=hidden_dropout_prob), ) self.ffnn_tail = nn.SequentialCell( - nn.Dense(hidden_size * 2, hidden_size).to_float(self.dense_dtype), + nn.Dense(hidden_size * 2, hidden_size), nn.ReLU(), nn.Dropout(p=hidden_dropout_prob), - nn.Dense(hidden_size, hidden_size // 2).to_float(self.dense_dtype), + nn.Dense(hidden_size, hidden_size // 2), nn.ReLU(), nn.Dropout(p=hidden_dropout_prob), ) @@ -58,8 +56,8 @@ def construct(self, hidden_states, question, question_label, answer, answer_labe tmp_hidden_states = ops.gather_d(hidden_states, 1, answer) a_repr = ops.concat((tmp_hidden_states, a_label_repr), axis=-1) - q = self.ffnn_head(q_repr).astype(float32) - a = self.ffnn_tail(a_repr).astype(float32) + q = self.ffnn_head(q_repr) + a = self.ffnn_tail(a_repr) logits = self.rel_classifier(q, a) return logits @@ -68,13 +66,13 @@ class RelationExtractionHead(nn.Cell): """ Head of relation extraction tas """ - def __init__(self, use_visual_backbone: bool = True, use_float16: bool = False, dropout=None, **kwargs): + def __init__(self, 
use_visual_backbone: bool = True, dropout=None, **kwargs): super(RelationExtractionHead, self).__init__() - self.config = LayoutXLMPretrainedConfig(use_visual_backbone, use_float16) + self.config = LayoutXLMPretrainedConfig(use_visual_backbone) dropout_prob = dropout if dropout is not None else self.config.hidden_dropout_prob - self.extractor = REDecoder(self.config.hidden_size, dropout_prob, use_float16) + self.extractor = REDecoder(self.config.hidden_size, dropout_prob) self.dropout = nn.Dropout(p=dropout_prob) diff --git a/mindocr/models/heads/kie_tokenclassification_head.py b/mindocr/models/heads/kie_tokenclassification_head.py index aa0a59f25..0004ff3a8 100644 --- a/mindocr/models/heads/kie_tokenclassification_head.py +++ b/mindocr/models/heads/kie_tokenclassification_head.py @@ -9,19 +9,15 @@ def __init__( self, num_classes: int = 7, use_visual_backbone: bool = True, - use_float16: bool = False, dropout_prod=None, **kwargs ): super(TokenClassificationHead, self).__init__() self.num_classes = num_classes - dense_type = float32 - if use_float16 is True: - dense_type = float16 - self.config = LayoutXLMPretrainedConfig(use_visual_backbone, use_float16) + self.config = LayoutXLMPretrainedConfig(use_visual_backbone) dropout_prod = dropout_prod if dropout_prod is not None else self.config.hidden_dropout_prob self.dropout = nn.Dropout(p=dropout_prod) - self.classifier = nn.Dense(self.config.hidden_size, num_classes).to_float(dense_type) + self.classifier = nn.Dense(self.config.hidden_size, num_classes) @staticmethod def _init_weights(layer): diff --git a/mindocr/models/kie_layoutxlm.py b/mindocr/models/kie_layoutxlm.py index d5e4eee34..0f47364ab 100644 --- a/mindocr/models/kie_layoutxlm.py +++ b/mindocr/models/kie_layoutxlm.py @@ -44,7 +44,6 @@ def layoutxlm_ser( pretrained: bool = True, pretrained_backbone=False, use_visual_backbone: bool = True, - use_float16: bool = False, **kwargs ): model_config = { @@ -53,13 +52,11 @@ def layoutxlm_ser( "name": "layoutxlm", 
"pretrained": pretrained_backbone, # backbone pretrained "use_visual_backbone": use_visual_backbone, - "use_float16": use_float16, }, "head": { "name": "TokenClassificationHead", "num_classes": 7, "use_visual_backbone": use_visual_backbone, - "use_float16": use_float16, "dropout_prod": None, }, } @@ -72,20 +69,18 @@ def layoutxlm_ser( @register_model -def vi_layoutxlm_ser(pretrained: bool = True, use_visual_backbone: bool = False, use_float16: bool = False, **kwargs): +def vi_layoutxlm_ser(pretrained: bool = True, use_visual_backbone: bool = False, **kwargs): model_config = { "type": "kie", "backbone": { "name": "layoutxlm", "pretrained": pretrained, # backbone pretrained "use_visual_backbone": use_visual_backbone, - "use_float16": use_float16, }, "head": { "name": "TokenClassificationHead", "num_classes": 7, "use_visual_backbone": use_visual_backbone, - "use_float16": use_float16, "dropout_prod": None, }, }