diff --git a/configs/rec/can/README_CN.md b/configs/rec/can/README_CN.md
new file mode 100644
index 000000000..b6354d400
--- /dev/null
+++ b/configs/rec/can/README_CN.md
@@ -0,0 +1,51 @@
+[English]() | 中文
+
+# CAN (Counting-Aware Network)
+
+
+> [When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/pdf/2207.11463.pdf)
+
+## 1. Model Description
+
+
+CAN is a handwritten mathematical expression recognition (HMER) algorithm with an attention-based encoder-decoder architecture and a weakly-supervised counting module. Surveying existing HMER methods, the authors find that most of them adopt an attention-based encoder-decoder structure, which lets the model attend to the image region of the symbol currently being decoded. For regular text, attention moves in a simple pattern (usually left-to-right or right-to-left), so the mechanism is quite reliable in that setting. For mathematical expressions, however, attention can move across the image in many more ways; when decoding complex formulas the model is therefore prone to inaccurate attention, causing some symbols to be recognized repeatedly and others to be missed.
+
+To address this, the authors design a weakly-supervised counting module that predicts the number of occurrences of each symbol class without any symbol-level position annotations, and plug it into a typical attention-based encoder-decoder HMER model. The design rests on two observations: (1) symbol counting implicitly provides symbol position information, which makes attention more accurate; (2) the counting result serves as extra global information that improves recognition accuracy.
+
+
+
+
+
+ Figure 1. Comparison of handwritten mathematical expression recognition algorithms [1]
+
+
+The CAN model consists of a backbone feature extractor, a Multi-Scale Counting Module (MSCM), and a Counting-Combined Attentional Decoder (CCAD). The backbone uses DenseNet to produce a feature map, which is fed into MSCM to obtain a counting vector of dimension 1×C, where C is the size of the formula vocabulary. The counting vector and the feature map are then fed into CCAD, which outputs the LaTeX sequence of the formula.
+
+
+
+
+
+ Figure 2. Overall model architecture [1]
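+
+The overall flow can be exercised end to end with the modules added in this repository. The following is a minimal sketch, not the full training pipeline; hyperparameters follow the accompanying unit test (256×256 input, vocabulary size C=111, downsampling ratio 16):
+
+```python
+import mindspore as ms
+from mindspore import ops
+
+from mindocr.models.backbones import build_backbone
+from mindocr.models.heads import build_head
+
+ms.set_context(mode=ms.PYNATIVE_MODE)
+
+# DenseNet backbone -> 684-channel feature map
+backbone = build_backbone(
+    "rec_can_densenet", growth_rate=24, reduction=0.5,
+    bottleneck=True, use_dropout=True, input_channels=3,
+)
+# CANHead = MSCM counting branches + CCAD attentional decoder
+head = build_head(
+    "CANHead", in_channels=684, out_channels=111, ratio=16,
+    attdecoder={
+        "is_train": True, "input_size": 256, "hidden_size": 256,
+        "encoder_out_channel": 684, "dropout": True, "dropout_ratio": 0.5,
+        "word_num": 111, "counting_decoder_out_channel": 111,
+        "attention": {"attention_dim": 512, "word_conv_kernel": 1},
+    },
+)
+
+image = ops.randn((1, 3, 256, 256))
+images_mask = ops.ones((1, 1, 256, 256))
+labels = ops.randint(low=0, high=111, size=(1, 10))  # teacher-forcing targets
+
+features = backbone(image)                   # (1, 684, 16, 16) feature map
+out = head(features, (images_mask, labels))
+print(out["counting_preds"].shape)           # (1, 111): the 1xC counting vector
+print(out["word_probs"].shape)               # (1, 10, 111): per-step symbol distribution
+```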
+
+
+The Multi-Scale Counting Module (MSCM) is designed to predict the number of occurrences of each symbol class. It consists of multi-scale feature extraction, channel attention, and a pooling operator. Owing to different writing habits, formula images contain symbols of various sizes, and a single convolution kernel size cannot handle such scale variation effectively. MSCM therefore first extracts multi-scale features with two parallel convolution branches using different kernel sizes (set to 3×3 and 5×5), and then applies channel attention after the convolution layers to further enhance the feature information.
+
+
+
+
+
+ Figure 3. MSCM multi-scale counting module [1]
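+
+A minimal sketch of the two-branch counting, using the `CountingDecoder` cell from this repository (`CANHead` averages a 3×3 branch and a 5×5 branch into the final counting vector):
+
+```python
+from mindspore import ops
+
+from mindocr.models.heads.rec_can_head import CountingDecoder
+
+in_channels, num_classes = 684, 111
+branch3 = CountingDecoder(in_channels, num_classes, 3)  # 3x3 kernel branch
+branch5 = CountingDecoder(in_channels, num_classes, 5)  # 5x5 kernel branch
+
+features = ops.randn((1, in_channels, 16, 16))  # backbone feature map
+mask = ops.ones((1, 1, 16, 16))                 # downsampled image mask
+
+counts3, _ = branch3(features, mask)            # (1, 111) per-class counts
+counts5, _ = branch5(features, mask)            # (1, 111) per-class counts
+counting_vector = (counts3 + counts5) / 2       # 1xC counting vector fed to CCAD
+```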
+
+
+Counting-Combined Attentional Decoder (CCAD): to strengthen the model's awareness of spatial position, positional encodings are used to represent the different spatial locations of the feature map. In addition, unlike most previous formula recognition methods that predict symbols from local features only, CCAD feeds the symbol counting result into the symbol classification step as extra global information to improve recognition accuracy.
+
+
+
+
+
+ Figure 4. Counting-combined attentional decoder CCAD [1]
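+
+A minimal sketch of the two ingredients CCAD adds on top of a plain attention decoder, spatial positional encoding and counting context, using the cells from this repository (the `nn.Dense` projection below mirrors `AttDecoder.counting_context_weight`):
+
+```python
+from mindspore import nn, ops
+
+from mindocr.models.heads.rec_can_head import PositionEmbeddingSine
+
+features = ops.randn((1, 512, 16, 16))  # feature map projected to attention_dim
+mask = ops.ones((1, 16, 16))            # (B, H, W) validity mask
+
+# sine/cosine positional encoding over the feature-map grid (512 channels)
+pos_embed = PositionEmbeddingSine(num_pos_feats=256, normalize=True)
+features = features + pos_embed(features, mask)
+
+# the counting vector is projected once and added to every decoding step's
+# hidden-state sum as global context
+counting_vector = ops.randn((1, 111))
+counting_context = nn.Dense(111, 256)(counting_vector)  # (1, hidden_size)
+```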
+
+
+## References
+
+[1] Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai. When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition. arXiv:2207.11463, ECCV'2022
diff --git a/mindocr/models/backbones/__init__.py b/mindocr/models/backbones/__init__.py
index 8f71299ab..70132195c 100644
--- a/mindocr/models/backbones/__init__.py
+++ b/mindocr/models/backbones/__init__.py
@@ -19,6 +19,7 @@
from .rec_vgg import *
from .table_master_resnet import *
from .yolov8_backbone import yolov8_backbone
+from .rec_can_densenet import *
__all__ = []
__all__.extend(builder.__all__)
diff --git a/mindocr/models/backbones/rec_can_densenet.py b/mindocr/models/backbones/rec_can_densenet.py
new file mode 100644
index 000000000..fb2c06193
--- /dev/null
+++ b/mindocr/models/backbones/rec_can_densenet.py
@@ -0,0 +1,197 @@
+"""
+Rec_DenseNet model
+"""
+import math
+
+from mindspore import nn
+from mindspore import ops
+from ._registry import register_backbone, register_backbone_class
+
+
+__all__ = ['DenseNet']
+
+
+class Bottleneck(nn.Cell):
+ """Bottleneck block of rec_densenet"""
+ def __init__(self, n_channels, growth_rate, use_dropout):
+ super().__init__()
+ inter_channels = 4 * growth_rate
+ self.bn1 = nn.BatchNorm2d(inter_channels)
+ self.conv1 = nn.Conv2d(
+ n_channels,
+ inter_channels,
+ kernel_size=1,
+ has_bias=False,
+ pad_mode='pad',
+ padding=0,
+ )
+ self.bn2 = nn.BatchNorm2d(growth_rate)
+ self.conv2 = nn.Conv2d(
+ inter_channels,
+ growth_rate,
+ kernel_size=3,
+ has_bias=False,
+ pad_mode='pad',
+ padding=1
+ )
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def construct(self, x):
+ out = ops.relu(self.bn1(self.conv1(x)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = ops.relu(self.bn2(self.conv2(out)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = ops.concat((x, out), 1)
+ return out
+
+
+class SingleLayer(nn.Cell):
+ """SingleLayer block of rec_densenet"""
+ def __init__(self, n_channels, growth_rate, use_dropout):
+ super().__init__()
+ self.bn1 = nn.BatchNorm2d(n_channels)
+ self.conv1 = nn.Conv2d(
+ n_channels,
+ growth_rate,
+ kernel_size=3,
+ has_bias=False,
+ pad_mode='pad',
+ padding=1
+ )
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def construct(self, x):
+ out = self.conv1(ops.relu(x))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = ops.concat((x, out), 1)
+ return out
+
+
+class Transition(nn.Cell):
+ """Transition Module of rec_densenet"""
+ def __init__(self, n_channels, out_channels, use_dropout):
+ super().__init__()
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.conv1 = nn.Conv2d(
+ n_channels,
+ out_channels,
+ kernel_size=1,
+ has_bias=False
+ )
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+
+ def construct(self, x):
+ out = ops.relu(self.bn1(self.conv1(x)))
+
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = ops.avg_pool2d(out, 2, stride=2, ceil_mode=True)
+ return out
+
+
+@register_backbone_class
+class DenseNet(nn.Cell):
+ r"""The RecDenseNet model is the customized DenseNet backbone for
+ Handwritten Mathematical Expression Recognition.
+ For example, in the CAN recognition algorithm, it is used in
+ feature extraction to obtain a formula feature map.
+    The DenseNet network is based on the
+    `"When Counting Meets HMER: Counting-Aware Network for
+    Handwritten Mathematical Expression Recognition"
+    <https://arxiv.org/abs/2207.11463>`_ paper.
+
+ Args:
+ growth_rate (int): growth rate of DenseNet. The default value is 24.
+ reduction (float): compression ratio in DenseNet. The default is 0.5.
+ bottleneck (bool): specifies whether to use a bottleneck layer. The default is True.
+ use_dropout (bool): indicates whether to use dropout. The default is True.
+ input_channels (int): indicates the number of channels in the input image. The default is 3.
+ Return:
+ nn.Cell for backbone module
+
+ Example:
+ >>> # init a DenseNet network
+ >>> params = {
+ >>> 'growth_rate': 24,
+ >>> 'reduction': 0.5,
+ >>> 'bottleneck': True,
+ >>> 'use_dropout': True,
+ >>> 'input_channels': 3,
+ >>> }
+ >>> model = DenseNet(**params)
+ """
+ def __init__(self, growth_rate, reduction, bottleneck, use_dropout, input_channels):
+ super().__init__()
+ n_dense_blocks = 16
+ n_channels = 2 * growth_rate
+
+ self.conv1 = nn.Conv2d(
+ input_channels,
+ n_channels,
+ kernel_size=7,
+ stride=2,
+ has_bias=False,
+ pad_mode='pad',
+ padding=3,
+ )
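+        # each dense block appends n_dense_blocks * growth_rate channels;
+        # each transition layer compresses channels by `reduction` and
+        # halves the spatial resolution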
+ self.dense1 = self.make_dense(
+ n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
+ )
+ n_channels += n_dense_blocks * growth_rate
+ out_channels = int(math.floor(n_channels * reduction))
+ self.trans1 = Transition(n_channels, out_channels, use_dropout)
+
+ n_channels = out_channels
+ self.dense2 = self.make_dense(
+ n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
+ )
+ n_channels += n_dense_blocks * growth_rate
+ out_channels = int(math.floor(n_channels * reduction))
+ self.trans2 = Transition(n_channels, out_channels, use_dropout)
+
+ n_channels = out_channels
+ self.dense3 = self.make_dense(
+ n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
+ )
+ n_channels += n_dense_blocks * growth_rate
+ self.out_channels = [n_channels]
+
+ def construct(self, x):
+ out = self.conv1(x)
+ out = ops.relu(out)
+ out = ops.max_pool2d(out, 2, ceil_mode=True)
+ out = self.dense1(out)
+ out = self.trans1(out)
+ out = self.dense2(out)
+ out = self.trans2(out)
+ out = self.dense3(out)
+ return out
+
+ def make_dense(self, n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout):
+ """Create dense_layer of DenseNet"""
+ layers = []
+ layer_constructor = Bottleneck if bottleneck else SingleLayer
+ for _ in range(int(n_dense_blocks)):
+ layers.append(layer_constructor(n_channels, growth_rate, use_dropout))
+ n_channels += growth_rate
+ return nn.SequentialCell(*layers)
+
+
+@register_backbone
+def rec_can_densenet(pretrained: bool = False, **kwargs) -> DenseNet:
+ """Create a rec_densenet backbone model."""
+ if pretrained is True:
+ raise NotImplementedError(
+ "The default pretrained checkpoint for `rec_densenet` backbone does not exist."
+ )
+
+ model = DenseNet(**kwargs)
+ return model
diff --git a/mindocr/models/heads/builder.py b/mindocr/models/heads/builder.py
index 286fc997f..bfb477e71 100644
--- a/mindocr/models/heads/builder.py
+++ b/mindocr/models/heads/builder.py
@@ -18,6 +18,7 @@
'YOLOv8Head',
'MultiHead',
'TableMasterHead',
+ 'CANHead',
]
from .cls_head import MobileNetV3Head
from .conv_head import ConvHead
@@ -36,6 +37,7 @@
from .rec_visionlan_head import VisionLANHead
from .table_master_head import TableMasterHead
from .yolov8_head import YOLOv8Head
+from .rec_can_head import CANHead
def build_head(head_name, **kwargs):
diff --git a/mindocr/models/heads/rec_can_head.py b/mindocr/models/heads/rec_can_head.py
new file mode 100644
index 000000000..4ae53921a
--- /dev/null
+++ b/mindocr/models/heads/rec_can_head.py
@@ -0,0 +1,386 @@
+"""
+CAN_HEAD_MODULE
+"""
+import math
+import mindspore as ms
+from mindspore import nn
+from mindspore import ops
+
+
+class ChannelAtt(nn.Cell):
+ """Channel Attention of the Counting Module"""
+ def __init__(self, channel, reduction):
+ super().__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.SequentialCell([
+ nn.Dense(channel, channel // reduction),
+ nn.ReLU(),
+ nn.Dense(channel // reduction, channel),
+ nn.Sigmoid()
+ ])
+
+ def construct(self, x):
+ b, c, _, _ = x.shape
+ y = ops.reshape(self.avg_pool(x), (b, c))
+ y = ops.reshape(self.fc(y), (b, c, 1, 1))
+ return x * y
+
+
+class CountingDecoder(nn.Cell):
+ """Single Counting Module"""
+ def __init__(self, in_channel, out_channel, kernel_size):
+ super().__init__()
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ self.trans_layer = nn.SequentialCell([
+ nn.Conv2d(
+ self.in_channel,
+ 512,
+ kernel_size=kernel_size,
+ pad_mode='pad',
+ padding=kernel_size // 2,
+ has_bias=False,
+ ),
+ nn.BatchNorm2d(512)
+ ])
+
+ self.channel_att = ChannelAtt(512, 16)
+
+ self.pred_layer = nn.SequentialCell([
+ nn.Conv2d(
+ 512,
+ self.out_channel,
+ kernel_size=1,
+ has_bias=False,
+ ),
+ nn.Sigmoid()
+ ])
+
+ def construct(self, x, mask):
+ b, _, h, w = x.shape
+ x = self.trans_layer(x)
+ x = self.channel_att(x)
+ x = self.pred_layer(x)
+
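+        # x is now a per-class density map; masking removes padded regions,
+        # and summing over spatial positions yields the per-class counts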
+ if mask is not None:
+ x = x * mask
+ x = ops.reshape(x, (b, self.out_channel, -1))
+ x1 = ops.sum(x, -1)
+
+ return x1, ops.reshape(x, (b, self.out_channel, h, w))
+
+
+class Attention(nn.Cell):
+ """Attention Module"""
+ def __init__(self, hidden_size, attention_dim):
+ super().__init__()
+ self.hidden = hidden_size
+ self.attention_dim = attention_dim
+ self.hidden_weight = nn.Dense(self.hidden, self.attention_dim)
+ self.attention_conv = nn.Conv2d(
+ 1,
+ 512,
+ kernel_size=11,
+ pad_mode='pad',
+ padding=5,
+ has_bias=False
+ )
+ self.attention_weight = nn.Dense(512, self.attention_dim, has_bias=False)
+ self.alpha_convert = nn.Dense(self.attention_dim, 1)
+
+ def construct(
+ self, cnn_features, cnn_features_trans, hidden, alpha_sum, image_mask=None
+ ):
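+        # coverage attention: alpha_sum accumulates the attention maps of all
+        # previous decoding steps, discouraging the decoder from re-attending
+        # to regions it has already read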
+ query = self.hidden_weight(hidden)
+ alpha_sum_trans = self.attention_conv(alpha_sum)
+ coverage_alpha = self.attention_weight(alpha_sum_trans.permute(0, 2, 3, 1))
+ query_expanded = ops.unsqueeze(ops.unsqueeze(query, 1), 2)
+ alpha_score = ops.tanh(
+ query_expanded
+ + coverage_alpha
+ + cnn_features_trans.permute(0, 2, 3, 1)
+ )
+ energy = self.alpha_convert(alpha_score)
+ energy = energy - energy.max()
+ energy_exp = ops.exp(energy.squeeze(-1))
+ if image_mask is not None:
+ energy_exp = energy_exp * image_mask.squeeze(1)
+ alpha = energy_exp / (
+ ops.unsqueeze(ops.unsqueeze(ops.sum(ops.sum(energy_exp, -1), -1), 1), 2) + 1e-10
+ )
+ alpha_sum = ops.unsqueeze(alpha, 1) + alpha_sum
+ context_vector = ops.sum(
+ ops.sum((ops.unsqueeze(alpha, 1) * cnn_features), -1), -1
+ )
+
+ return context_vector, alpha, alpha_sum
+
+
+class PositionEmbeddingSine(nn.Cell):
+ """Position Embedding Sine Module of the Attention Decoder"""
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True when scale is provided")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def construct(self, x, mask):
+ y_embed = ops.cumsum(mask, 1, dtype=ms.float32)
+ x_embed = ops.cumsum(mask, 2, dtype=ms.float32)
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = ops.arange(self.num_pos_feats, dtype=ms.float32)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
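+        # interleave sine and cosine over frequency pairs, in the style of
+        # Transformer / DETR 2D positional embeddings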
+
+ pos_x = ops.unsqueeze(x_embed, 3) / dim_t
+ pos_y = ops.unsqueeze(y_embed, 3) / dim_t
+
+ pos_x = ops.flatten(
+ ops.stack(
+ [ops.sin(pos_x[:, :, :, 0::2]), ops.cos(pos_x[:, :, :, 1::2])],
+ axis=4,
+ ),
+ 'C',
+ start_dim=3,
+ )
+ pos_y = ops.flatten(
+ ops.stack(
+ [ops.sin(pos_y[:, :, :, 0::2]), ops.cos(pos_y[:, :, :, 1::2])],
+ axis=4,
+ ),
+ 'C',
+ start_dim=3,
+ )
+
+ pos = ops.concat([pos_x, pos_y], axis=3)
+ pos = ops.transpose(pos, (0, 3, 1, 2))
+ return pos
+
+
+class AttDecoder(nn.Cell):
+ """Attention Decoder Module"""
+ def __init__(
+ self,
+ ratio,
+ is_train,
+ input_size,
+ hidden_size,
+ encoder_out_channel,
+ dropout,
+ dropout_ratio,
+ word_num,
+ counting_decoder_out_channel,
+ attention,
+ ):
+ super().__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.out_channel = encoder_out_channel
+ self.attention_dim = attention["attention_dim"]
+ self.dropout_prob = dropout
+ self.ratio = ratio
+ self.word_num = word_num
+ self.counting_num = counting_decoder_out_channel
+ self.is_train = is_train
+
+ self.init_weight = nn.Dense(self.out_channel, self.hidden_size)
+ self.embedding = nn.Embedding(self.word_num, self.input_size)
+ self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
+ self.word_attention = Attention(self.hidden_size, self.attention_dim)
+
+ self.encoder_feature_conv = nn.Conv2d(
+ self.out_channel,
+ self.attention_dim,
+ kernel_size=attention["word_conv_kernel"],
+ pad_mode="pad",
+ padding=attention["word_conv_kernel"] // 2,
+ )
+
+ self.word_state_weight = nn.Dense(self.hidden_size, self.hidden_size)
+ self.word_embedding_weight = nn.Dense(self.input_size, self.hidden_size)
+ self.word_context_weight = nn.Dense(self.out_channel, self.hidden_size)
+ self.counting_context_weight = nn.Dense(self.counting_num, self.hidden_size)
+ self.word_convert = nn.Dense(self.hidden_size, self.word_num)
+
+ if dropout:
+ self.dropout = nn.Dropout(p=dropout_ratio)
+
+ def construct(self, cnn_features, labels, counting_preds, images_mask, is_train=True):
+ if is_train:
+ _, num_steps = labels.shape
+ else:
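+            # fixed maximum decoding length at inference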
+ num_steps = 36
+
+ batch_size, _, height, width = cnn_features.shape
+ images_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
+
+ word_probs = ops.zeros((batch_size, num_steps, self.word_num))
+ word_alpha_sum = ops.zeros((batch_size, 1, height, width))
+
+ hidden = self.init_hidden(cnn_features, images_mask)
+ counting_context_weighted = self.counting_context_weight(counting_preds)
+ cnn_features_trans = self.encoder_feature_conv(cnn_features)
+
+ position_embedding = PositionEmbeddingSine(256, normalize=True)
+ pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])
+ cnn_features_trans = cnn_features_trans + pos
+
+ word = ops.ones((batch_size, 1), dtype=ms.int64)
+ word = ops.squeeze(word, axis=1)
+
+ for i in range(num_steps):
+ word_embedding = self.embedding(word)
+ hidden = self.word_input_gru(word_embedding, hidden)
+ word_context_vec, _, word_alpha_sum = self.word_attention(
+ cnn_features,
+ cnn_features_trans,
+ hidden,
+ word_alpha_sum,
+ images_mask
+ )
+
+ current_state = self.word_state_weight(hidden)
+ word_weight_embedding = self.word_embedding_weight(word_embedding)
+ word_context_weighted = self.word_context_weight(word_context_vec)
+
+ if self.dropout_prob:
+ word_out_state = self.dropout(
+ current_state
+ + word_weight_embedding
+ + word_context_weighted
+ + counting_context_weighted
+ )
+ else:
+ word_out_state = (
+ current_state
+ + word_weight_embedding
+ + word_context_weighted
+ + counting_context_weighted
+ )
+
+ word_prob = self.word_convert(word_out_state)
+ word_probs[:, i] = word_prob
+
+ if self.is_train:
+ word = labels[:, i]
+ else:
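+                # greedy decoding; `labels` here acts as a token-validity mask
+                # rather than ground truth (as in the reference implementation)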
+ word = word_prob.argmax(1)
+ word = ops.multiply(
+ word, labels[:, i]
+ )
+
+ return word_probs
+
+ def init_hidden(self, features, feature_mask):
+ """Used to initialize the hidden layer"""
+ average = ops.sum(
+ ops.sum(features * feature_mask, dim=-1), dim=-1
+ ) / ops.sum((ops.sum(feature_mask, dim=-1)), dim=-1)
+ average = self.init_weight(average)
+ return ops.tanh(average)
+
+
+class CANHead(nn.Cell):
+ r"""The CAN model is an algorithm used to recognize
+ handwritten mathematical formulas.
+ CAN Network is based on
+ `"When Counting Meets HMER: Counting-Aware Network
+ for Handwritten Mathematical Expression Recognition"
+ `_ paper.
+
+    Args:
+        in_channels: number of channels of the input feature map.
+        out_channels: number of symbol classes to count (vocabulary size).
+        ratio: downsampling ratio between the input image and the feature map.
+        attdecoder: the parameters needed to build an AttDecoder:
+            - is_train: indicates whether the model is in training mode.
+            - input_size: size of the word embedding input.
+            - hidden_size: hidden state size of the GRU decoder.
+            - encoder_out_channel: number of channels of the encoder output feature.
+            - dropout: whether to use dropout.
+            - dropout_ratio: the dropout probability.
+            - word_num: number of words (vocabulary size).
+            - counting_decoder_out_channel: number of output channels of the counting decoder.
+            - attention: the parameters needed to build the attention module:
+                - attention_dim: dimension of the attention mechanism.
+                - word_conv_kernel: kernel size of the convolution applied to
+                  encoder features before attention.
+
+    Return:
+        word_probs: per-step word probability distribution.
+        counting_preds1: counting prediction from the 3x3-kernel counting branch.
+        counting_preds2: counting prediction from the 5x5-kernel counting branch.
+        counting_preds: the element-wise average of the two counting predictions.
+
+
+ Example:
+ >>> # init a CANHead network
+ >>> in_channels = 684
+ >>> out_channels = 111
+ >>> ratio = 16
+ >>> attdecoder_params = {
+ >>> 'is_train': True,
+ >>> 'input_size': 256,
+ >>> 'hidden_size': 256,
+ >>> 'encoder_out_channel': in_channels,
+ >>> 'dropout': True,
+ >>> 'dropout_ratio': 0.5,
+ >>> 'word_num': 111,
+ >>> 'counting_decoder_out_channel': out_channels,
+ >>> 'attention': {
+ >>> 'attention_dim': 512,
+ >>> 'word_conv_kernel': 1
+ >>> }
+ >>> }
+ >>> model = CANHead(in_channels, out_channels, ratio, attdecoder_params)
+ """
+ def __init__(self, in_channels, out_channels, ratio, attdecoder):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.counting_decoder1 = CountingDecoder(
+ self.in_channels, self.out_channels, 3
+ )
+ self.counting_decoder2 = CountingDecoder(
+ self.in_channels, self.out_channels, 5
+ )
+
+ self.decoder = AttDecoder(ratio, **attdecoder)
+
+ self.ratio = ratio
+
+ def construct(self, x, *args):
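+        # args[0] carries (images_mask, labels)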
+ cnn_features = x
+ images_mask = args[0][0]
+ labels = args[0][1]
+
+ counting_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
+ counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
+ counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
+ counting_preds = (counting_preds1 + counting_preds2) / 2
+
+ word_probs = self.decoder(cnn_features, labels, counting_preds, images_mask)
+
+ return {
+ 'word_probs': word_probs,
+ 'counting_preds': counting_preds,
+ 'counting_preds1': counting_preds1,
+ 'counting_preds2': counting_preds2
+ }
diff --git a/tests/ut/test_can_model.py b/tests/ut/test_can_model.py
new file mode 100644
index 000000000..5eba1c368
--- /dev/null
+++ b/tests/ut/test_can_model.py
@@ -0,0 +1,97 @@
+import sys
+
+sys.path.append(".")
+
+import mindspore as ms
+from mindspore import ops
+
+from mindocr.models.base_model import BaseModel
+from mindocr.models.backbones import build_backbone
+from mindocr.models.heads import build_head
+
+ms.set_context(mode=ms.PYNATIVE_MODE, pynative_synchronize=True)
+
+if __name__ == "__main__":
+ # model parameter setting
+ model_config = {
+ "backbone": {
+ "name": "rec_can_densenet",
+ "pretrained": False,
+ "growth_rate": 24,
+ "reduction": 0.5,
+ "bottleneck": True,
+ "use_dropout": True,
+ "input_channels": 3,
+ },
+ "neck": {
+ },
+ "head": {
+ "name": "CANHead",
+ "out_channels": 111,
+ "ratio": 16,
+ "attdecoder": {
+ "is_train": True,
+ "input_size": 256,
+ "hidden_size": 256,
+ "encoder_out_channel": 684,
+ "dropout": True,
+ "dropout_ratio": 0.5,
+ "word_num": 111,
+ "counting_decoder_out_channel": 111,
+ "attention": {
+ "attention_dim": 512,
+ "word_conv_kernel": 1,
+ },
+ },
+ },
+ }
+
+
+ # test case parameter settings
+ batch_size = 1
+ input_tensor_channel = 3
+ images_mask_channel = 1
+ num_steps = 10
+ word_num = 111
+ out_channels = 111
+ h = 256
+ w = 256
+
+ input_tensor = ops.randn((batch_size, input_tensor_channel, h, w))
+ images_mask = ops.ones((batch_size, images_mask_channel, h, w))
+ labels = ops.randint(low=0, high=word_num, size=(batch_size, num_steps))
+
+
+ # basemodel unit test
+ model_config.pop("neck")
+ model = BaseModel(model_config)
+ hout = model(input_tensor, images_mask, labels)
+
+ assert hout["word_probs"].shape == (batch_size, num_steps, word_num), "Word probabilities shape is incorrect"
+ assert hout["counting_preds"].shape == (batch_size, out_channels), "Counting predictions shape is incorrect"
+ assert hout["counting_preds1"].shape == (batch_size, out_channels), "Counting predictions 1 shape is incorrect"
+ assert hout["counting_preds2"].shape == (batch_size, out_channels), "Counting predictions 2 shape is incorrect"
+
+
+ # build_backbone unit test
+ backbone_name = model_config["backbone"].pop("name")
+ backbone = build_backbone(backbone_name, **model_config["backbone"])
+ bout = backbone(input_tensor)
+
+
+    bout_c = backbone.out_channels[-1]  # the paper specifies a 684-channel feature map
+    bout_h = h // model_config["head"]["ratio"]
+    bout_w = w // model_config["head"]["ratio"]
+    assert bout_c == 684, "bout channel is incorrect"
+    assert bout.shape == (batch_size, bout_c, bout_h, bout_w), "bout shape is incorrect"
+
+
+ # build_head unit test
+ head_name = model_config["head"].pop("name")
+ head = build_head(head_name, in_channels=bout_c, **model_config["head"])
+    head_args = (images_mask, labels)
+ hout = head(bout, head_args)
+
+ assert hout["word_probs"].shape == (batch_size, num_steps, word_num), "Word probabilities shape is incorrect"
+ assert hout["counting_preds"].shape == (batch_size, out_channels), "Counting predictions shape is incorrect"
+ assert hout["counting_preds1"].shape == (batch_size, out_channels), "Counting predictions 1 shape is incorrect"
+ assert hout["counting_preds2"].shape == (batch_size, out_channels), "Counting predictions 2 shape is incorrect"