diff --git a/src/otx/algorithms/classification/adapters/openvino/task.py b/src/otx/algorithms/classification/adapters/openvino/task.py index ac4ed4b874e..33c90e19ae8 100644 --- a/src/otx/algorithms/classification/adapters/openvino/task.py +++ b/src/otx/algorithms/classification/adapters/openvino/task.py @@ -126,7 +126,11 @@ def __init__( self.model = Model.create_model(model_adapter, "otx_classification", self.configuration, preload=True) - self.converter = ClassificationToAnnotationConverter(self.label_schema) + if self.model.hierarchical: + hierarchical_cls_heads_info = self.model.hierarchical_info["cls_heads_info"] + else: + hierarchical_cls_heads_info = None + self.converter = ClassificationToAnnotationConverter(self.label_schema, hierarchical_cls_heads_info) self.callback_exceptions: List[Exception] = [] self.model.inference_adapter.set_callback(self._async_callback) diff --git a/src/otx/algorithms/classification/utils/cls_utils.py b/src/otx/algorithms/classification/utils/cls_utils.py index 968586a7d5a..023ec8c415a 100644 --- a/src/otx/algorithms/classification/utils/cls_utils.py +++ b/src/otx/algorithms/classification/utils/cls_utils.py @@ -18,8 +18,9 @@ import json from operator import itemgetter -from typing import Any, Dict +from typing import Any, Dict, List +from otx.api.entities.label import LabelEntity from otx.api.entities.label_schema import LabelSchemaEntity from otx.api.serialization.label_mapper import LabelSchemaMapper @@ -51,8 +52,8 @@ def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disabl for j, group in enumerate(single_label_groups): class_to_idx[group[0]] = (len(exclusive_groups), j) - all_labels = label_schema.get_labels(include_empty=False) - label_to_idx = {lbl.name: i for i, lbl in enumerate(all_labels)} + # Idx of label corresponds to model output + label_to_idx = {lbl: i for i, lbl in enumerate(class_to_idx.keys())} mixed_cls_heads_info = { "num_multiclass_heads": len(exclusive_groups), @@ -104,9 +105,13 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c mapi_config[("model_info", "hierarchical")] = str(inference_config["hierarchical"]) mapi_config[("model_info", "output_raw_scores")] = str(True) + label_entities = label_schema.get_labels(include_empty=False) + if inference_config["hierarchical"]: + label_entities = get_hierarchical_label_list(inference_config["multihead_class_info"], label_entities) + all_labels = "" all_label_ids = "" - for lbl in label_schema.get_labels(include_empty=False): + for lbl in label_entities: all_labels += lbl.name.replace(" ", "_") + " " all_label_ids += f"{lbl.id_} " @@ -123,22 +128,16 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c return mapi_config -def get_hierarchical_label_list(hierarchical_info, labels): +def get_hierarchical_label_list(hierarchical_cls_heads_info: Dict, labels: List) -> List[LabelEntity]: """Return hierarchical labels list which is adjusted to model outputs classes.""" + + # Create the list of Label Entities (took from "labels") + # corresponding to names and order in "label_to_idx" + label_to_idx = hierarchical_cls_heads_info["label_to_idx"] hierarchical_labels = [] - for head_idx in range(hierarchical_info["num_multiclass_heads"]): - logits_begin, logits_end = hierarchical_info["head_idx_to_logits_range"][str(head_idx)] - for logit in range(0, logits_end - logits_begin): - label_str = hierarchical_info["all_groups"][head_idx][logit] - label_idx = hierarchical_info["label_to_idx"][label_str] - hierarchical_labels.append(labels[label_idx]) - - if hierarchical_info["num_multilabel_classes"]: - logits_begin = hierarchical_info["num_single_label_classes"] - logits_end = len(labels) - for logit_idx, logit in enumerate(range(0, logits_end - logits_begin)): - label_str_idx = hierarchical_info["num_multiclass_heads"] + logit_idx - label_str = hierarchical_info["all_groups"][label_str_idx][0] - label_idx = hierarchical_info["label_to_idx"][label_str] - hierarchical_labels.append(labels[label_idx]) + for label_str, _ in label_to_idx.items(): + for label_entity in labels: + if label_entity.name == label_str: + hierarchical_labels.append(label_entity) + break return hierarchical_labels diff --git a/src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py b/src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py index 22a60455231..8f90824b1a4 100644 --- a/src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py +++ b/src/otx/api/usecases/exportable_code/prediction_to_annotation_converter.py @@ -12,6 +12,7 @@ from openvino.model_api.models import utils from openvino.model_api.models.utils import AnomalyResult +from otx.algorithms.classification.utils import get_hierarchical_label_list from otx.api.entities.annotation import ( Annotation, AnnotationSceneEntity, @@ -171,7 +172,11 @@ def create_converter( elif converter_type == Domain.SEGMENTATION: converter = SegmentationToAnnotationConverter(labels) elif converter_type == Domain.CLASSIFICATION: - converter = ClassificationToAnnotationConverter(labels) + if configuration is not None and configuration.get("hierarchical", False): + hierarchical_cls_heads_info = configuration["multihead_class_info"] + else: + hierarchical_cls_heads_info = None + converter = ClassificationToAnnotationConverter(labels, hierarchical_cls_heads_info) elif converter_type == Domain.ANOMALY_CLASSIFICATION: converter = AnomalyClassificationToAnnotationConverter(labels) elif converter_type == Domain.ANOMALY_DETECTION: @@ -268,9 +273,10 @@ class ClassificationToAnnotationConverter(IPredictionToAnnotationConverter): Args: labels (LabelSchemaEntity): Label Schema containing the label info of the task + hierarchical_cls_heads_info (Dict): Info from model.hierarchical_info["cls_heads_info"] """ - def __init__(self, label_schema: LabelSchemaEntity): + def __init__(self, label_schema: LabelSchemaEntity, hierarchical_cls_heads_info: Optional[Dict] = None): if len(label_schema.get_labels(False)) == 1: self.labels = label_schema.get_labels(include_empty=True) else: @@ -284,6 +290,9 @@ def __init__(self, label_schema: LabelSchemaEntity): self.label_schema = label_schema + if self.hierarchical: + self.labels = get_hierarchical_label_list(hierarchical_cls_heads_info, self.labels) + def convert_to_annotation( self, predictions: List[Tuple[int, float]], metadata: Optional[Dict] = None ) -> AnnotationSceneEntity: diff --git a/tests/unit/algorithms/classification/tasks/test_classification_openvino_task.py b/tests/unit/algorithms/classification/tasks/test_classification_openvino_task.py index 504265d4fd9..98393a96b55 100644 --- a/tests/unit/algorithms/classification/tasks/test_classification_openvino_task.py +++ b/tests/unit/algorithms/classification/tasks/test_classification_openvino_task.py @@ -81,6 +81,7 @@ def test_post_process(self): } fake_metadata = {"original_shape": (254, 320, 3), "resized_shape": (224, 224, 3)} self.cls_ov_inferencer.model.postprocess.return_value = [[0, 0.87], [1, 0.13]] + self.cls_ov_inferencer.model.hierarchical = False returned_value = self.cls_ov_inferencer.post_process(fake_prediction, fake_metadata) assert len(returned_value.annotations[0].get_labels()) > 0 diff --git a/tests/unit/api/usecases/exportable_code/test_prediction_to_annotation_converter.py b/tests/unit/api/usecases/exportable_code/test_prediction_to_annotation_converter.py index fdc81c94101..9144f593c41 100644 --- a/tests/unit/api/usecases/exportable_code/test_prediction_to_annotation_converter.py +++ b/tests/unit/api/usecases/exportable_code/test_prediction_to_annotation_converter.py @@ -749,7 +749,10 @@ def test_classification_to_annotation_init(self): labels=other_non_empty_labels, ) label_schema = LabelSchemaEntity(label_groups=[label_group, other_label_group]) - converter = ClassificationToAnnotationConverter(label_schema=label_schema) + hierarchical_cls_heads_info = {"label_to_idx": {label_0_1.name: 0, label_0_1_1.name: 1, label_0_2.name: 2}} + converter = ClassificationToAnnotationConverter( + label_schema=label_schema, hierarchical_cls_heads_info=hierarchical_cls_heads_info + ) assert not converter.empty_label assert converter.label_schema == label_schema assert converter.hierarchical @@ -840,7 +843,10 @@ def check_annotation(actual_annotation: Annotation, expected_labels: list): label_schema = LabelSchemaEntity(label_groups=[label_group, other_label_group]) label_schema.add_child(parent=label_0_1, child=label_0_1_1) - converter = ClassificationToAnnotationConverter(label_schema=label_schema) + hierarchical_cls_heads_info = {"label_to_idx": {label_0_1.name: 0, label_0_1_1.name: 1, label_0_2.name: 2}} + converter = ClassificationToAnnotationConverter( + label_schema=label_schema, hierarchical_cls_heads_info=hierarchical_cls_heads_info + ) predictions = [(2, 0.9), (1, 0.8)] predictions_to_annotations = converter.convert_to_annotation(predictions) check_annotation_scene(annotation_scene=predictions_to_annotations, expected_length=1)