diff --git a/configs/rec/can/README_CN.md b/configs/rec/can/README_CN.md
new file mode 100644
index 000000000..b6354d400
--- /dev/null
+++ b/configs/rec/can/README_CN.md
@@ -0,0 +1,51 @@
+[English]() | 中文
+# CAN (Counting-Aware Network)
+> [CAN: When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition](https://arxiv.org/pdf/2207.11463.pdf)
+## 1. 模型描述
+ 图1. 手写数学公式识别算法对比 [1]
+CAN模型由主干特征提取网络、多尺度计数模块(MSCM)和结合计数的注意力解码器(CCAD)构成。主干特征提取网络采用DenseNet得到特征图,并将特征图输入MSCM,得到一个计数向量(Counting Vector)。该计数向量的维度为1×C,其中C为公式词表大小。随后把这个计数向量和特征图一起输入到CCAD中,最终输出公式的LaTeX序列。
+ 图2. 整体模型结构 [1]
+多尺度计数模块(MSCM)旨在预测每个符号类别的数量,其由多尺度特征提取、通道注意力和池化算子组成。由于书写习惯的不同,公式图像通常包含各种大小的符号,单一大小的卷积核无法有效处理这种尺度变化。为此,首先利用两个并行的卷积分支,分别使用不同大小的卷积核(3×3 和 5×5)来提取多尺度特征。在卷积层之后,采用通道注意力来进一步增强特征信息。
+ 图3. MSCM多尺度计数模块 [1]
+ 图4. 结合计数的注意力解码器CCAD [1]
+## 参考文献
+[1] Bohan Li, Ye Yuan, Dingkang Liang, Xiao Liu, Zhilong Ji, Jinfeng Bai, Wenyu Liu, Xiang Bai. When Counting Meets HMER: Counting-Aware Network for Handwritten Mathematical Expression Recognition. arXiv:2207.11463, ECCV'2022
diff --git a/mindocr/models/backbones/__init__.py b/mindocr/models/backbones/__init__.py
index 8f71299ab..70132195c 100644
--- a/mindocr/models/backbones/__init__.py
+++ b/mindocr/models/backbones/__init__.py
@@ -19,6 +19,7 @@
from .rec_vgg import *
from .table_master_resnet import *
from .yolov8_backbone import yolov8_backbone
+from .rec_can_densenet import *
__all__ = []
diff --git a/mindocr/models/backbones/rec_can_densenet.py b/mindocr/models/backbones/rec_can_densenet.py
new file mode 100644
index 000000000..fb2c06193
--- /dev/null
+++ b/mindocr/models/backbones/rec_can_densenet.py
@@ -0,0 +1,197 @@
+"""Rec_DenseNet model"""
+import math
+import mindspore as ms
+from mindspore import nn
+from mindspore import ops
+from ._registry import register_backbone, register_backbone_class
+__all__ = ['DenseNet']
+class Bottleneck(nn.Cell):
+ """Bottleneck block of rec_densenet"""
+ def __init__(self, n_channels, growth_rate, use_dropout):
+ super().__init__()
+ inter_channels = 4 * growth_rate
+ self.bn1 = nn.BatchNorm2d(inter_channels)
+ self.conv1 = nn.Conv2d(
+ n_channels,
+ inter_channels,
+ kernel_size=1,
+ has_bias=False,
+ pad_mode='pad',
+ padding=0,
+ )
+ self.bn2 = nn.BatchNorm2d(growth_rate)
+ self.conv2 = nn.Conv2d(
+ inter_channels,
+ growth_rate,
+ kernel_size=3,
+ has_bias=False,
+ pad_mode='pad',
+ padding=1
+ )
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+ def construct(self, x):
+ out = ops.relu(self.bn1(self.conv1(x)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = ops.relu(self.bn2(self.conv2(out)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = ops.concat((x, out), 1)
+ return out
+class SingleLayer(nn.Cell):
+ """SingleLayer block of rec_densenet"""
+ def __init__(self, n_channels, growth_rate, use_dropout):
+ super().__init__()
+ self.bn1 = nn.BatchNorm2d(n_channels)
+ self.conv1 = nn.Conv2d(
+ n_channels,
+ growth_rate,
+ kernel_size=3,
+ has_bias=False,
+ pad_mode='pad',
+ padding=1
+ )
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+ def construct(self, x):
+ out = self.conv1(ops.relu(x))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = ops.concat((x, out), 1)
+ return out
+class Transition(nn.Cell):
+ """Transition Module of rec_densenet"""
+ def __init__(self, n_channels, out_channels, use_dropout):
+ super().__init__()
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.conv1 = nn.Conv2d(
+ n_channels,
+ out_channels,
+ kernel_size=1,
+ has_bias=False
+ )
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(p=0.2)
+ def construct(self, x):
+ out = ops.relu(self.bn1(self.conv1(x)))
+ if self.use_dropout:
+ out = self.dropout(out)
+ out = ops.avg_pool2d(out, 2, stride=2, ceil_mode=True)
+ return out
+class DenseNet(nn.Cell):
+ r"""The RecDenseNet model is the customized DenseNet backbone for
+ Handwritten Mathematical Expression Recognition.
+ For example, in the CAN recognition algorithm, it is used in
+ feature extraction to obtain a formula feature map.
+ DenseNet Network is based on
+ `"When Counting Meets HMER: Counting-Aware Network for
+ Handwritten Mathematical Expression Recognition"
+ `_ paper.
+ Args:
+ growth_rate (int): growth rate of DenseNet. The default value is 24.
+ reduction (float): compression ratio in DenseNet. The default is 0.5.
+ bottleneck (bool): specifies whether to use a bottleneck layer. The default is True.
+ use_dropout (bool): indicates whether to use dropout. The default is True.
+ input_channels (int): indicates the number of channels in the input image. The default is 3.
+ Return:
+ nn.Cell for backbone module
+ Example:
+ >>> # init a DenseNet network
+ >>> params = {
+ >>> 'growth_rate': 24,
+ >>> 'reduction': 0.5,
+ >>> 'bottleneck': True,
+ >>> 'use_dropout': True,
+ >>> 'input_channels': 3,
+ >>> }
+ >>> model = DenseNet(**params)
+ """
+ def __init__(self, growth_rate, reduction, bottleneck, use_dropout, input_channels):
+ super().__init__()
+ n_dense_blocks = 16
+ n_channels = 2 * growth_rate
+ self.conv1 = nn.Conv2d(
+ input_channels,
+ n_channels,
+ kernel_size=7,
+ stride=2,
+ has_bias=False,
+ pad_mode='pad',
+ padding=3,
+ )
+ self.dense1 = self.make_dense(
+ n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
+ )
+ n_channels += n_dense_blocks * growth_rate
+ out_channels = int(math.floor(n_channels * reduction))
+ self.trans1 = Transition(n_channels, out_channels, use_dropout)
+ n_channels = out_channels
+ self.dense2 = self.make_dense(
+ n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
+ )
+ n_channels += n_dense_blocks * growth_rate
+ out_channels = int(math.floor(n_channels * reduction))
+ self.trans2 = Transition(n_channels, out_channels, use_dropout)
+ n_channels = out_channels
+ self.dense3 = self.make_dense(
+ n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout
+ )
+ n_channels += n_dense_blocks * growth_rate
+ self.out_channels = [n_channels]
+ def construct(self, x):
+ out = self.conv1(x)
+ out = ops.relu(out)
+ out = ops.max_pool2d(out, 2, ceil_mode=True)
+ out = self.dense1(out)
+ out = self.trans1(out)
+ out = self.dense2(out)
+ out = self.trans2(out)
+ out = self.dense3(out)
+ return out
+ def make_dense(self, n_channels, growth_rate, n_dense_blocks, bottleneck, use_dropout):
+ """Create dense_layer of DenseNet"""
+ layers = []
+ layer_constructor = Bottleneck if bottleneck else SingleLayer
+ for _ in range(int(n_dense_blocks)):
+ layers.append(layer_constructor(n_channels, growth_rate, use_dropout))
+ n_channels += growth_rate
+ return nn.SequentialCell(*layers)
+def rec_can_densenet(pretrained: bool = False, **kwargs) -> DenseNet:
+ """Create a rec_densenet backbone model."""
+ if pretrained is True:
+ raise NotImplementedError(
+ "The default pretrained checkpoint for `rec_densenet` backbone does not exist."
+ )
+ model = DenseNet(**kwargs)
+ return model
diff --git a/mindocr/models/heads/builder.py b/mindocr/models/heads/builder.py
index 286fc997f..bfb477e71 100644
--- a/mindocr/models/heads/builder.py
+++ b/mindocr/models/heads/builder.py
@@ -18,6 +18,7 @@
+ 'CANHead',
from .cls_head import MobileNetV3Head
from .conv_head import ConvHead
@@ -36,6 +37,7 @@
from .rec_visionlan_head import VisionLANHead
from .table_master_head import TableMasterHead
from .yolov8_head import YOLOv8Head
+from .rec_can_head import CANHead
def build_head(head_name, **kwargs):
diff --git a/mindocr/models/heads/rec_can_head.py b/mindocr/models/heads/rec_can_head.py
new file mode 100644
index 000000000..4ae53921a
--- /dev/null
+++ b/mindocr/models/heads/rec_can_head.py
@@ -0,0 +1,386 @@
+import math
+import mindspore as ms
+from mindspore import nn
+from mindspore import ops
+ms.set_context(mode=ms.PYNATIVE_MODE, pynative_synchronize=True)
+class ChannelAtt(nn.Cell):
+ """Channel Attention of the Counting Module"""
+ def __init__(self, channel, reduction):
+ super().__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.SequentialCell([
+ nn.Dense(channel, channel // reduction),
+ nn.ReLU(),
+ nn.Dense(channel // reduction, channel),
+ nn.Sigmoid()
+ ])
+ def construct(self, x):
+ b, c, _, _ = x.shape
+ y = ops.reshape(self.avg_pool(x), (b, c))
+ y = ops.reshape(self.fc(y), (b, c, 1, 1))
+ return x * y
+class CountingDecoder(nn.Cell):
+ """Single Counting Module"""
+ def __init__(self, in_channel, out_channel, kernel_size):
+ super().__init__()
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.trans_layer = nn.SequentialCell([
+ nn.Conv2d(
+ self.in_channel,
+ 512,
+ kernel_size=kernel_size,
+ pad_mode='pad',
+ padding=kernel_size // 2,
+ has_bias=False,
+ ),
+ nn.BatchNorm2d(512)
+ ])
+ self.channel_att = ChannelAtt(512, 16)
+ self.pred_layer = nn.SequentialCell([
+ nn.Conv2d(
+ 512,
+ self.out_channel,
+ kernel_size=1,
+ has_bias=False,
+ ),
+ nn.Sigmoid()
+ ])
+ def construct(self, x, mask):
+ b, _, h, w = x.shape
+ x = self.trans_layer(x)
+ x = self.channel_att(x)
+ x = self.pred_layer(x)
+ if mask is not None:
+ x = x * mask
+ x = ops.reshape(x, (b, self.out_channel, -1))
+ x1 = ops.sum(x, -1)
+ return x1, ops.reshape(x, (b, self.out_channel, h, w))
+class Attention(nn.Cell):
+ """Attention Module"""
+ def __init__(self, hidden_size, attention_dim):
+ super().__init__()
+ self.hidden = hidden_size
+ self.attention_dim = attention_dim
+ self.hidden_weight = nn.Dense(self.hidden, self.attention_dim)
+ self.attention_conv = nn.Conv2d(
+ 1,
+ 512,
+ kernel_size=11,
+ pad_mode='pad',
+ padding=5,
+ has_bias=False
+ )
+ self.attention_weight = nn.Dense(512, self.attention_dim, has_bias=False)
+ self.alpha_convert = nn.Dense(self.attention_dim, 1)
+ def construct(
+ self, cnn_features, cnn_features_trans, hidden, alpha_sum, image_mask=None
+ ):
+ query = self.hidden_weight(hidden)
+ alpha_sum_trans = self.attention_conv(alpha_sum)
+ coverage_alpha = self.attention_weight(alpha_sum_trans.permute(0, 2, 3, 1))
+ query_expanded = ops.unsqueeze(ops.unsqueeze(query, 1), 2)
+ alpha_score = ops.tanh(
+ query_expanded
+ + coverage_alpha
+ + cnn_features_trans.permute(0, 2, 3, 1)
+ )
+ energy = self.alpha_convert(alpha_score)
+ energy = energy - energy.max()
+ energy_exp = ops.exp(energy.squeeze(-1))
+ if image_mask is not None:
+ energy_exp = energy_exp * image_mask.squeeze(1)
+ alpha = energy_exp / (
+ ops.unsqueeze(ops.unsqueeze(ops.sum(ops.sum(energy_exp, -1), -1), 1), 2) + 1e-10
+ )
+ alpha_sum = ops.unsqueeze(alpha, 1) + alpha_sum
+ context_vector = ops.sum(
+ ops.sum((ops.unsqueeze(alpha, 1) * cnn_features), -1), -1
+ )
+ return context_vector, alpha, alpha_sum
+class PositionEmbeddingSine(nn.Cell):
+ """Position Embedding Sine Module of the Attention Decoder"""
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True when scale is provided")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+ def construct(self, x, mask):
+ y_embed = ops.cumsum(mask, 1, dtype=ms.float32)
+ x_embed = ops.cumsum(mask, 2, dtype=ms.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+ dim_t = ops.arange(self.num_pos_feats, dtype=ms.float32)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+ pos_x = ops.unsqueeze(x_embed, 3) / dim_t
+ pos_y = ops.unsqueeze(y_embed, 3) / dim_t
+ pos_x = ops.flatten(
+ ops.stack(
+ [ops.sin(pos_x[:, :, :, 0::2]), ops.cos(pos_x[:, :, :, 1::2])],
+ axis=4,
+ ),
+ 'C',
+ start_dim=3,
+ )
+ pos_y = ops.flatten(
+ ops.stack(
+ [ops.sin(pos_y[:, :, :, 0::2]), ops.cos(pos_y[:, :, :, 1::2])],
+ axis=4,
+ ),
+ 'C',
+ start_dim=3,
+ )
+ pos = ops.concat([pos_x, pos_y], axis=3)
+ pos = ops.transpose(pos, (0, 3, 1, 2))
+ return pos
+class AttDecoder(nn.Cell):
+ """Attention Decoder Module"""
+ def __init__(
+ self,
+ ratio,
+ is_train,
+ input_size,
+ hidden_size,
+ encoder_out_channel,
+ dropout,
+ dropout_ratio,
+ word_num,
+ counting_decoder_out_channel,
+ attention,
+ ):
+ super().__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.out_channel = encoder_out_channel
+ self.attention_dim = attention["attention_dim"]
+ self.dropout_prob = dropout
+ self.ratio = ratio
+ self.word_num = word_num
+ self.counting_num = counting_decoder_out_channel
+ self.is_train = is_train
+ self.init_weight = nn.Dense(self.out_channel, self.hidden_size)
+ self.embedding = nn.Embedding(self.word_num, self.input_size)
+ self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
+ self.word_attention = Attention(self.hidden_size, self.attention_dim)
+ self.encoder_feature_conv = nn.Conv2d(
+ self.out_channel,
+ self.attention_dim,
+ kernel_size=attention["word_conv_kernel"],
+ pad_mode="pad",
+ padding=attention["word_conv_kernel"] // 2,
+ )
+ self.word_state_weight = nn.Dense(self.hidden_size, self.hidden_size)
+ self.word_embedding_weight = nn.Dense(self.input_size, self.hidden_size)
+ self.word_context_weight = nn.Dense(self.out_channel, self.hidden_size)
+ self.counting_context_weight = nn.Dense(self.counting_num, self.hidden_size)
+ self.word_convert = nn.Dense(self.hidden_size, self.word_num)
+ if dropout:
+ self.dropout = nn.Dropout(p=dropout_ratio)
+ def construct(self, cnn_features, labels, counting_preds, images_mask, is_train=True):
+ if is_train:
+ _, num_steps = labels.shape
+ else:
+ num_steps = 36
+ batch_size, _, height, width = cnn_features.shape
+ images_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
+ word_probs = ops.zeros((batch_size, num_steps, self.word_num))
+ word_alpha_sum = ops.zeros((batch_size, 1, height, width))
+ hidden = self.init_hidden(cnn_features, images_mask)
+ counting_context_weighted = self.counting_context_weight(counting_preds)
+ cnn_features_trans = self.encoder_feature_conv(cnn_features)
+ position_embedding = PositionEmbeddingSine(256, normalize=True)
+ pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])
+ cnn_features_trans = cnn_features_trans + pos
+ word = ops.ones((batch_size, 1), dtype=ms.int64)
+ word = ops.squeeze(word, axis=1)
+ for i in range(num_steps):
+ word_embedding = self.embedding(word)
+ hidden = self.word_input_gru(word_embedding, hidden)
+ word_context_vec, _, word_alpha_sum = self.word_attention(
+ cnn_features,
+ cnn_features_trans,
+ hidden,
+ word_alpha_sum,
+ images_mask
+ )
+ current_state = self.word_state_weight(hidden)
+ word_weight_embedding = self.word_embedding_weight(word_embedding)
+ word_context_weighted = self.word_context_weight(word_context_vec)
+ if self.dropout_prob:
+ word_out_state = self.dropout(
+ current_state
+ + word_weight_embedding
+ + word_context_weighted
+ + counting_context_weighted
+ )
+ else:
+ word_out_state = (
+ current_state
+ + word_weight_embedding
+ + word_context_weighted
+ + counting_context_weighted
+ )
+ word_prob = self.word_convert(word_out_state)
+ word_probs[:, i] = word_prob
+ if self.is_train:
+ word = labels[:, i]
+ else:
+ word = word_prob.argmax(1)
+ word = ops.multiply(
+ word, labels[:, i]
+ )
+ return word_probs
+ def init_hidden(self, features, feature_mask):
+ """Used to initialize the hidden layer"""
+ average = ops.sum(
+ ops.sum(features * feature_mask, dim=-1), dim=-1
+ ) / ops.sum((ops.sum(feature_mask, dim=-1)), dim=-1)
+ average = self.init_weight(average)
+ return ops.tanh(average)
+class CANHead(nn.Cell):
+ r"""The CAN model is an algorithm used to recognize
+ handwritten mathematical formulas.
+ CAN Network is based on
+ `"When Counting Meets HMER: Counting-Aware Network
+ for Handwritten Mathematical Expression Recognition"
+ `_ paper.
+ Args:
+ "in_channels": number of channels for the input feature.
+ "out_channels": number of channels for the output feature.
+ "ratio": the ratio used to downsample the feature map.
+ "attdecoder", the parameters needed to build an AttDecoder:
+ - "is_train": indicates whether the model is in training mode.
+ - "input_size":eEnter the size.
+ - "hidden_size": Hidden layer size.
+ - "encoder_out_channel": number of channels for the encoder output feature.
+ - "dropout": whether to use dropout.
+ - "dropout_ratio": the ratio of dropout.
+ - "word_num": number of words.
+ - "counting_decoder_out_channel": counts the decoder's output channels.
+ - "attention", the parameters needed to build an attention mechanism:
+ - "attention_dim": the dimension of the attention mechanism.
+ - "word_conv_kernel": the size of the lexical convolution kernel.
+ Return:
+ "word_probs": word probability distribution.
+ "counting_preds1": count prediction 1, the number of words
+ predicted by the 3*3 convolution kernel.
+ "counting_preds2": count prediction 2, the number of words
+ predicted by the 5*5 convolution kernel.
+ "counting_preds": the mean predicted by the above two counts.
+ Example:
+ >>> # init a CANHead network
+ >>> in_channels = 684
+ >>> out_channels = 111
+ >>> ratio = 16
+ >>> attdecoder_params = {
+ >>> 'is_train': True,
+ >>> 'input_size': 256,
+ >>> 'hidden_size': 256,
+ >>> 'encoder_out_channel': in_channels,
+ >>> 'dropout': True,
+ >>> 'dropout_ratio': 0.5,
+ >>> 'word_num': 111,
+ >>> 'counting_decoder_out_channel': out_channels,
+ >>> 'attention': {
+ >>> 'attention_dim': 512,
+ >>> 'word_conv_kernel': 1
+ >>> }
+ >>> }
+ >>> model = CANHead(in_channels, out_channels, ratio, attdecoder_params)
+ """
+ def __init__(self, in_channels, out_channels, ratio, attdecoder):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.counting_decoder1 = CountingDecoder(
+ self.in_channels, self.out_channels, 3
+ )
+ self.counting_decoder2 = CountingDecoder(
+ self.in_channels, self.out_channels, 5
+ )
+ self.decoder = AttDecoder(ratio, **attdecoder)
+ self.ratio = ratio
+ def construct(self, x, *args):
+ cnn_features = x
+ images_mask = args[0][0]
+ labels = args[0][1]
+ counting_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
+ counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
+ counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
+ counting_preds = (counting_preds1 + counting_preds2) / 2
+ word_probs = self.decoder(cnn_features, labels, counting_preds, images_mask)
+ return {
+ 'word_probs': word_probs,
+ 'counting_preds': counting_preds,
+ 'counting_preds1': counting_preds1,
+ 'counting_preds2': counting_preds2
+ }
diff --git a/tests/ut/test_can_model.py b/tests/ut/test_can_model.py
new file mode 100644
index 000000000..5eba1c368
--- /dev/null
+++ b/tests/ut/test_can_model.py
@@ -0,0 +1,97 @@
+import sys
+import mindocr
+import mindspore as ms
+from mindocr.models.base_model import BaseModel
+from mindocr.models.backbones import build_backbone
+from mindocr.models.heads import build_head
+from mindspore import ops
+ms.set_context(mode=ms.PYNATIVE_MODE, pynative_synchronize=True)
+if __name__ == "__main__":
+ # model parameter setting
+ model_config = {
+ "backbone": {
+ "name": "rec_can_densenet",
+ "pretrained": False,
+ "growth_rate": 24,
+ "reduction": 0.5,
+ "bottleneck": True,
+ "use_dropout": True,
+ "input_channels": 3,
+ },
+ "neck": {
+ },
+ "head": {
+ "name": "CANHead",
+ "out_channels": 111,
+ "ratio": 16,
+ "attdecoder": {
+ "is_train": True,
+ "input_size": 256,
+ "hidden_size": 256,
+ "encoder_out_channel": 684,
+ "dropout": True,
+ "dropout_ratio": 0.5,
+ "word_num": 111,
+ "counting_decoder_out_channel": 111,
+ "attention": {
+ "attention_dim": 512,
+ "word_conv_kernel": 1,
+ },
+ },
+ },
+ }
+ # test case parameter settings
+ batch_size = 1
+ input_tensor_channel = 3
+ images_mask_channel = 1
+ num_steps = 10
+ word_num = 111
+ out_channels = 111
+ h = 256
+ w = 256
+ input_tensor = ops.randn((batch_size, input_tensor_channel, h, w))
+ images_mask = ops.ones((batch_size, images_mask_channel, h, w))
+ labels = ops.randint(low=0, high=word_num, size=(batch_size, num_steps))
+ # basemodel unit test
+ model_config.pop("neck")
+ model = BaseModel(model_config)
+ hout = model(input_tensor, images_mask, labels)
+ assert hout["word_probs"].shape == (batch_size, num_steps, word_num), "Word probabilities shape is incorrect"
+ assert hout["counting_preds"].shape == (batch_size, out_channels), "Counting predictions shape is incorrect"
+ assert hout["counting_preds1"].shape == (batch_size, out_channels), "Counting predictions 1 shape is incorrect"
+ assert hout["counting_preds2"].shape == (batch_size, out_channels), "Counting predictions 2 shape is incorrect"
+ # build_backbone unit test
+ backbone_name = model_config["backbone"].pop("name")
+ backbone = build_backbone(backbone_name, **model_config["backbone"])
+ bout = backbone(input_tensor)
+ bout_c = backbone.out_channels[-1] #The paper specified 684 features to be extracted
+ bout_h = h/model_config["head"]["ratio"]
+ bout_w = w/model_config["head"]["ratio"]
+ assert bout_c == 684, "bout channel is incorrect"
+ assert bout.shape == (batch_size, bout_c, bout_h, bout_w), "bout shape is incorrect"
+ # build_head unit test
+ head_name = model_config["head"].pop("name")
+ head = build_head(head_name, in_channels=bout_c, **model_config["head"])
+ head_args = ((images_mask, labels))
+ hout = head(bout, head_args)
+ assert hout["word_probs"].shape == (batch_size, num_steps, word_num), "Word probabilities shape is incorrect"
+ assert hout["counting_preds"].shape == (batch_size, out_channels), "Counting predictions shape is incorrect"
+ assert hout["counting_preds1"].shape == (batch_size, out_channels), "Counting predictions 1 shape is incorrect"
+ assert hout["counting_preds2"].shape == (batch_size, out_channels), "Counting predictions 2 shape is incorrect"