Skip to content

Commit

Permalink
Fix label_to_idx for hierarchical classification (#2906)
Browse files Browse the repository at this point in the history
* Fix label_to_idx

* Fix tests

* Fix integration tests

* Fixes from comments

* Remove extra changes

* Update labels and label_idx in .xml

* Change hierarchical_info -> hierarchical_cls_heads_info
  • Loading branch information
GalyaZalesskaya authored Feb 14, 2024
1 parent 1a91b0b commit d46b69a
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 25 deletions.
6 changes: 5 additions & 1 deletion src/otx/algorithms/classification/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ def __init__(

self.model = Model.create_model(model_adapter, "otx_classification", self.configuration, preload=True)

self.converter = ClassificationToAnnotationConverter(self.label_schema)
if self.model.hierarchical:
hierarchical_cls_heads_info = self.model.hierarchical_info["cls_heads_info"]
else:
hierarchical_cls_heads_info = None
self.converter = ClassificationToAnnotationConverter(self.label_schema, hierarchical_cls_heads_info)
self.callback_exceptions: List[Exception] = []
self.model.inference_adapter.set_callback(self._async_callback)

Expand Down
39 changes: 19 additions & 20 deletions src/otx/algorithms/classification/utils/cls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

import json
from operator import itemgetter
from typing import Any, Dict
from typing import Any, Dict, List

from otx.api.entities.label import LabelEntity
from otx.api.entities.label_schema import LabelSchemaEntity
from otx.api.serialization.label_mapper import LabelSchemaMapper

Expand Down Expand Up @@ -51,8 +52,8 @@ def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disabl
for j, group in enumerate(single_label_groups):
class_to_idx[group[0]] = (len(exclusive_groups), j)

all_labels = label_schema.get_labels(include_empty=False)
label_to_idx = {lbl.name: i for i, lbl in enumerate(all_labels)}
# Idx of label corresponds to model output
label_to_idx = {lbl: i for i, lbl in enumerate(class_to_idx.keys())}

mixed_cls_heads_info = {
"num_multiclass_heads": len(exclusive_groups),
Expand Down Expand Up @@ -104,9 +105,13 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
mapi_config[("model_info", "hierarchical")] = str(inference_config["hierarchical"])
mapi_config[("model_info", "output_raw_scores")] = str(True)

label_entities = label_schema.get_labels(include_empty=False)
if inference_config["hierarchical"]:
label_entities = get_hierarchical_label_list(inference_config["multihead_class_info"], label_entities)

all_labels = ""
all_label_ids = ""
for lbl in label_schema.get_labels(include_empty=False):
for lbl in label_entities:
all_labels += lbl.name.replace(" ", "_") + " "
all_label_ids += f"{lbl.id_} "

Expand All @@ -123,22 +128,16 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
return mapi_config


def get_hierarchical_label_list(hierarchical_info, labels):
def get_hierarchical_label_list(hierarchical_cls_heads_info: Dict, labels: List) -> List[LabelEntity]:
"""Return hierarchical labels list which is adjusted to model outputs classes."""

# Create the list of Label Entities (took from "labels")
# corresponding to names and order in "label_to_idx"
label_to_idx = hierarchical_cls_heads_info["label_to_idx"]
hierarchical_labels = []
for head_idx in range(hierarchical_info["num_multiclass_heads"]):
logits_begin, logits_end = hierarchical_info["head_idx_to_logits_range"][str(head_idx)]
for logit in range(0, logits_end - logits_begin):
label_str = hierarchical_info["all_groups"][head_idx][logit]
label_idx = hierarchical_info["label_to_idx"][label_str]
hierarchical_labels.append(labels[label_idx])

if hierarchical_info["num_multilabel_classes"]:
logits_begin = hierarchical_info["num_single_label_classes"]
logits_end = len(labels)
for logit_idx, logit in enumerate(range(0, logits_end - logits_begin)):
label_str_idx = hierarchical_info["num_multiclass_heads"] + logit_idx
label_str = hierarchical_info["all_groups"][label_str_idx][0]
label_idx = hierarchical_info["label_to_idx"][label_str]
hierarchical_labels.append(labels[label_idx])
for label_str, _ in label_to_idx.items():
for label_entity in labels:
if label_entity.name == label_str:
hierarchical_labels.append(label_entity)
break
return hierarchical_labels
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from openvino.model_api.models import utils
from openvino.model_api.models.utils import AnomalyResult

from otx.algorithms.classification.utils import get_hierarchical_label_list
from otx.api.entities.annotation import (
Annotation,
AnnotationSceneEntity,
Expand Down Expand Up @@ -171,7 +172,11 @@ def create_converter(
elif converter_type == Domain.SEGMENTATION:
converter = SegmentationToAnnotationConverter(labels)
elif converter_type == Domain.CLASSIFICATION:
converter = ClassificationToAnnotationConverter(labels)
if configuration is not None and configuration.get("hierarchical", False):
hierarchical_cls_heads_info = configuration["multihead_class_info"]
else:
hierarchical_cls_heads_info = None
converter = ClassificationToAnnotationConverter(labels, hierarchical_cls_heads_info)
elif converter_type == Domain.ANOMALY_CLASSIFICATION:
converter = AnomalyClassificationToAnnotationConverter(labels)
elif converter_type == Domain.ANOMALY_DETECTION:
Expand Down Expand Up @@ -268,9 +273,10 @@ class ClassificationToAnnotationConverter(IPredictionToAnnotationConverter):
Args:
labels (LabelSchemaEntity): Label Schema containing the label info of the task
hierarchical_cls_heads_info (Dict): Info from model.hierarchical_info["cls_heads_info"]
"""

def __init__(self, label_schema: LabelSchemaEntity):
def __init__(self, label_schema: LabelSchemaEntity, hierarchical_cls_heads_info: Optional[Dict] = None):
if len(label_schema.get_labels(False)) == 1:
self.labels = label_schema.get_labels(include_empty=True)
else:
Expand All @@ -284,6 +290,9 @@ def __init__(self, label_schema: LabelSchemaEntity):

self.label_schema = label_schema

if self.hierarchical:
self.labels = get_hierarchical_label_list(hierarchical_cls_heads_info, self.labels)

def convert_to_annotation(
self, predictions: List[Tuple[int, float]], metadata: Optional[Dict] = None
) -> AnnotationSceneEntity:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def test_post_process(self):
}
fake_metadata = {"original_shape": (254, 320, 3), "resized_shape": (224, 224, 3)}
self.cls_ov_inferencer.model.postprocess.return_value = [[0, 0.87], [1, 0.13]]
self.cls_ov_inferencer.model.hierarchical = False
returned_value = self.cls_ov_inferencer.post_process(fake_prediction, fake_metadata)

assert len(returned_value.annotations[0].get_labels()) > 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,10 @@ def test_classification_to_annotation_init(self):
labels=other_non_empty_labels,
)
label_schema = LabelSchemaEntity(label_groups=[label_group, other_label_group])
converter = ClassificationToAnnotationConverter(label_schema=label_schema)
hierarchical_cls_heads_info = {"label_to_idx": {label_0_1.name: 0, label_0_1_1.name: 1, label_0_2.name: 2}}
converter = ClassificationToAnnotationConverter(
label_schema=label_schema, hierarchical_cls_heads_info=hierarchical_cls_heads_info
)
assert not converter.empty_label
assert converter.label_schema == label_schema
assert converter.hierarchical
Expand Down Expand Up @@ -840,7 +843,10 @@ def check_annotation(actual_annotation: Annotation, expected_labels: list):
label_schema = LabelSchemaEntity(label_groups=[label_group, other_label_group])

label_schema.add_child(parent=label_0_1, child=label_0_1_1)
converter = ClassificationToAnnotationConverter(label_schema=label_schema)
hierarchical_cls_heads_info = {"label_to_idx": {label_0_1.name: 0, label_0_1_1.name: 1, label_0_2.name: 2}}
converter = ClassificationToAnnotationConverter(
label_schema=label_schema, hierarchical_cls_heads_info=hierarchical_cls_heads_info
)
predictions = [(2, 0.9), (1, 0.8)]
predictions_to_annotations = converter.convert_to_annotation(predictions)
check_annotation_scene(annotation_scene=predictions_to_annotations, expected_length=1)
Expand Down

0 comments on commit d46b69a

Please sign in to comment.