Rectify code of the LayoutLM series models and adjust it to amp_level mode #693

Open · wants to merge 1 commit into base: main
configs/kie/layoutlmv3/README.md (1 addition & 2 deletions)

```diff
@@ -144,7 +144,7 @@ Apart from the dataset setting, please also check the following important args:
 system:
   mode:
   distribute: False # `True` for distributed training, `False` for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -157,7 +157,6 @@ model:
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 ...
 train:
```
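For context, `amp_level` in MindSpore drives auto mixed precision for the whole network, which is why the per-layer `use_float16` switches removed throughout this PR become redundant once `O3` is requested globally. A minimal sketch of the two levels, assuming MindSpore 2.x and a toy `nn.Dense` layer rather than the real LayoutLM backbone:

```python
import mindspore as ms
from mindspore import nn

x = ms.numpy.ones((2, 16), dtype=ms.float32)

# O0: everything stays in float32 (the old setting in these configs).
net_fp32 = ms.amp.auto_mixed_precision(nn.Dense(16, 7), amp_level="O0")

# O3: the whole network is cast to float16 (the new setting), so
# hand-written per-layer float16 casts are no longer needed.
net_fp16 = ms.amp.auto_mixed_precision(nn.Dense(16, 7), amp_level="O3")

print(net_fp32(x).dtype)  # expected: Float32
print(net_fp16(x).dtype)  # expected: Float16
```

With `O0`, compute stays in float32 throughout; with `O3`, MindSpore casts the network to float16 and inserts the needed input casts itself.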
configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml (1 addition & 2 deletions)

```diff
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: "O0"
+  amp_level: "O3"
   seed: 42
   log_interval: 10
   val_start_epoch: 50
@@ -17,7 +17,6 @@ model:
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 
 postprocess:
```
configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml (1 addition & 3 deletions)

```diff
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -16,11 +16,9 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: True
-    use_float16: True
   head:
     name: RelationExtractionHead
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 
 postprocess:
```
configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml (1 addition & 3 deletions)

```diff
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -15,12 +15,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: True
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 
 postprocess:
```
configs/kie/vi_layoutxlm/README.md (1 addition & 3 deletions)

```diff
@@ -159,7 +159,7 @@ Apart from the dataset setting, please also check the following important args:
 system:
   mode:
   distribute: False # `True` for distributed training, `False` for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -171,12 +171,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 ...
 train:
```
configs/kie/vi_layoutxlm/README_CN.md (1 addition & 3 deletions)

```diff
@@ -156,7 +156,7 @@ eval:
 system:
   mode:
   distribute: False # True for distributed training, False for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -168,12 +168,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 ...
 train:
```
configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml (2 additions & 4 deletions)

```diff
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: "O0"
+  amp_level: "O3"
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -16,11 +16,9 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head:
     name: RelationExtractionHead
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 
 postprocess:
@@ -90,11 +88,11 @@ train:
       "bbox",
       "attention_mask",
       "token_type_ids",
-      "image",
       "question",
       "question_label",
       "answer",
       "answer_label",
+      "image",
       "relation_label",
     ]
   net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
```
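The reorder in the last hunk matters because `net_input_column_index` picks dataset columns by position when feeding the network. A hypothetical sketch of the mechanism (the column names outside the visible hunk, such as `input_ids`, are assumptions, not the actual loader code):

```python
# Hypothetical column list mirroring the updated config above.
output_columns = [
    "input_ids", "bbox", "attention_mask", "token_type_ids",
    "question", "question_label", "answer", "answer_label",
    "image",          # moved after the label columns by this PR
    "relation_label", # left for the loss function, not the network
]
net_input_column_index = [0, 1, 2, 3, 4, 5, 6, 7, 8]

batch = {name: f"<{name} tensor>" for name in output_columns}
net_inputs = [batch[output_columns[i]] for i in net_input_column_index]
print(net_inputs)  # "image" now arrives as the 9th forward argument, not the 5th
```

Because the indices are positional, moving `"image"` in `output_columns` is what actually changes the order of arguments reaching the network's `construct`.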
configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml (1 addition & 3 deletions)

```diff
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -15,12 +15,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 
 postprocess:
```
mindocr/losses/kie_loss.py (1 addition & 1 deletion)

```diff
@@ -28,6 +28,6 @@ def __init__(self, **kwargs):
         super().__init__()
         self.loss_fct = nn.CrossEntropyLoss()
 
-    def construct(self, predicts, attention_mask, labels):
+    def construct(self, predicts, labels):
         loss = self.loss_fct(predicts.transpose(0, 2, 1), labels.astype(ms.int32))
         return loss
```
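The dropped `attention_mask` argument was never referenced in the loss body, so removing it only tightens the signature. The `transpose(0, 2, 1)` is there because MindSpore's `nn.CrossEntropyLoss` expects the class dimension on axis 1 for higher-rank logits. A standalone sketch with toy shapes, not mindocr's real batch:

```python
import numpy as np
import mindspore as ms
from mindspore import nn

loss_fct = nn.CrossEntropyLoss()

batch, seq_len, num_classes = 2, 5, 7
# Token-classification logits come out as (batch, seq_len, num_classes) ...
predicts = ms.Tensor(np.random.randn(batch, seq_len, num_classes), ms.float32)
labels = ms.Tensor(np.random.randint(0, num_classes, (batch, seq_len)), ms.int32)

# ... but CrossEntropyLoss wants (batch, num_classes, seq_len) when the
# logits have rank > 2, hence the transpose(0, 2, 1) in the loss above.
loss = loss_fct(predicts.transpose(0, 2, 1), labels)
print(loss)
```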
mindocr/models/backbones/layoutlmv3/configuration.py (1 addition & 2 deletions)

```diff
@@ -3,9 +3,8 @@
 
 @dataclass
 class LayoutLMv3PretrainedConfig:
-    def __init__(self, use_float16=False):
+    def __init__(self):
         pretrained_config = {
-            "use_float16": use_float16,
            "fast_qkv": False,
             "vocab_size": 250002,
             "hidden_size": 768,
```
mindocr/models/backbones/layoutlmv3/layoutlmv3.py (12 additions & 9 deletions)

```diff
@@ -175,7 +175,7 @@ def construct(
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
-            attention_scores = attention_scores + attention_mask.astype(self.dense_dtype)
+            attention_scores = attention_scores + attention_mask.astype(attention_scores.dtype)
 
         # Normalize the attention scores to probabilities.
         # Use the trick of the CogView paper to stablize training
@@ -227,11 +227,8 @@ def __init__(self, config):
         self.has_relative_attention_bias = config.has_relative_attention_bias
         self.has_spatial_attention_bias = config.has_spatial_attention_bias
         self.patch_size = config.patch_size
-        self.use_float16 = config.use_float16
-        self.dense_dtype = mstype.float32
-        if self.use_float16 is True:
-            self.dense_dtype = mstype.float16
-        self.min = finfo(self.dense_dtype)
+        self.float32_min = finfo(mstype.float32)
+        self.float16_min = finfo(mstype.float16)
         self.out_channels = 1
         self.use_visual_backbone = True
 
@@ -342,7 +339,13 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape, dtype
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely. # fp16 compatibility
         extended_attention_mask = extended_attention_mask.astype(dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * self.min
+
+        if dtype == mstype.float32:
+            minimum = self.float32_min
+        elif dtype == mstype.float16:
+            minimum = self.float16_min
+
+        extended_attention_mask = (1.0 - extended_attention_mask) * minimum
         return extended_attention_mask
 
     def get_head_mask(self, head_mask, num_hidden_layers: int, is_attention_chunked: bool = False):
@@ -518,7 +521,7 @@ def construct(
 
 
 @register_backbone
-def layoutlmv3(use_float16: bool = True, **kwargs):
-    pretrained_config = LayoutLMv3PretrainedConfig(use_float16)
+def layoutlmv3(**kwargs):
+    pretrained_config = LayoutLMv3PretrainedConfig()
     model = LayoutLMv3Model(pretrained_config)
     return model
```
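Selecting the additive-mask fill value from the runtime dtype avoids overflowing half precision: float32's most negative value is not representable in float16. A quick numpy check, illustrative only:

```python
import numpy as np

print(np.finfo(np.float32).min)  # about -3.4e38
print(np.finfo(np.float16).min)  # -65504.0

# Casting the float32 minimum into float16 overflows to -inf; -inf masked
# scores can then turn into NaNs under max-subtraction softmax tricks.
# Picking the minimum per dtype, as this patch does, keeps the mask finite.
print(np.float16(np.finfo(np.float32).min))  # -inf
```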
mindocr/models/backbones/layoutxlm/configuration.py (1 addition & 3 deletions)

```diff
@@ -3,13 +3,11 @@
 
 @dataclass
 class LayoutXLMPretrainedConfig:
-    def __init__(self, use_visual_backbone=True, use_float16=False):
+    def __init__(self, use_visual_backbone=True):
         pretrained_config = {
             "use_visual_backbone": use_visual_backbone,
-            "use_float16": use_float16,
             "attention_probs_dropout_prob": 0.1,
             "use_visual_backbone": use_visual_backbone,
-            "use_float16": use_float16,
             "bos_token_id": 0,
             "coordinate_size": 128,
             "eos_token_id": 2,
```
mindocr/models/backbones/layoutxlm/layoutxlm.py (3 additions & 9 deletions)

```diff
@@ -99,18 +99,12 @@ def __init__(self, config):
         self.has_visual_segment_embedding = config.has_visual_segment_embedding
         self.embeddings = LayoutXLMEmbeddings(config)
         self.use_visual_backbone = config.use_visual_backbone
-        self.use_float16 = config.use_float16
-        self.dense_dtype = mstype.float32
-        if self.use_float16 is True:
-            self.dense_dtype = mstype.float16
 
         if self.use_visual_backbone is True:
             set_context(jit_syntax_level=0)
             self.visual = VisualBackbone(config)
             self.visual.freeze()
-            self.visual_proj = nn.Dense(config.image_feature_pool_shape[-1], config.hidden_size).to_float(
-                self.dense_dtype
-            )
+            self.visual_proj = nn.Dense(config.image_feature_pool_shape[-1], config.hidden_size)
         if self.has_visual_segment_embedding:
             self.visual_segment_embedding = Parameter(nn.Embedding(1, config.hidden_size).embedding_table[0])
         self.visual_LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)
@@ -302,8 +296,8 @@ def construct(
 
 
 @register_backbone
-def layoutxlm(pretrained: bool = True, use_visual_backbone: bool = True, use_float16: bool = False, **kwargs):
-    pretrained_config = LayoutXLMPretrainedConfig(use_visual_backbone, use_float16)
+def layoutxlm(pretrained: bool = True, use_visual_backbone: bool = True, **kwargs):
+    pretrained_config = LayoutXLMPretrainedConfig(use_visual_backbone)
     model = LayoutXLMModel(pretrained_config)
     if pretrained:
         if use_visual_backbone is True:
```
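Dropping the hard-coded `.to_float(...)` on `visual_proj` moves the dtype decision from the layer into the global `amp_level`. A sketch of the behavioral difference with a toy layer, assuming MindSpore 2.x rather than the actual VisualBackbone wiring:

```python
import mindspore as ms
from mindspore import nn

x = ms.numpy.ones((1, 256), dtype=ms.float32)

# Old pattern: the layer itself pins a float16 compute dtype, so even an
# amp_level O0 (pure float32) run would execute this projection in fp16.
forced = nn.Dense(256, 768).to_float(ms.float16)
print(forced(x).dtype)  # Float16, regardless of the configured amp_level

# New pattern: leave the layer alone and let amp decide for the network.
plain = nn.Dense(256, 768)
amped = ms.amp.auto_mixed_precision(nn.Dense(256, 768), amp_level="O3")
print(plain(x).dtype)  # Float32
print(amped(x).dtype)  # Float16
```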
mindocr/models/backbones/layoutxlm/visual_backbone.py (1 addition & 1 deletion)

```diff
@@ -130,7 +130,7 @@ def construct(self, x):
             top_block_in_feature = bottom_up_features[self.top_block.in_feature]
         else:
             top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
-        results.extend(self.top_block(top_block_in_feature.astype(ms.float16)))
+        results.extend(self.top_block(top_block_in_feature))
 
         assert len(self._out_features) == len(results)
 
```
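As with `visual_proj` above, removing the explicit `astype(ms.float16)` lets the FPN top block run in whatever dtype the configured `amp_level` produces, so `O0` runs now stay in float32 end to end.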