diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index 2f5c00d544..cad57e8474 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -14,7 +14,7 @@ from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( - HierarchicalCBAMClsHead, + HierarchicalLinearClsHead, LinearClsHead, MultiLabelLinearClsHead, SemiSLLinearClsHead, @@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module: return HLabelClassifier( backbone=backbone, - neck=nn.Identity(), - head=HierarchicalCBAMClsHead( - in_channels=backbone.num_features, - **copied_head_config, - ), + neck=GlobalAveragePooling(dim=2), + head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features), multiclass_loss=nn.CrossEntropyLoss(), multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), )