
fix: analytics issue
ashish7515 committed Jan 22, 2025
1 parent 1417de0 commit 859597a
Showing 1 changed file with 37 additions and 32 deletions.
69 changes: 37 additions & 32 deletions cvat/apps/quality_control/quality_reports.py
@@ -341,6 +341,7 @@ def accumulate(self, other: ComparisonReportAnnotationShapeSummary):
"total_count",
"ds_count",
"gt_count",
"mean_iou",
]:
setattr(self, field, getattr(self, field) + getattr(other, field))

@@ -2183,25 +2184,25 @@ def _dm_ann_to_ann_id(self, ann):
job_id=source_data_provider.job_id,
)

def match_annotations(self, ds_annotations, gt_annotations):
def _interval_iou(self, interval1, interval2):
start1, end1 = interval1
start2, end2 = interval2

start2 += self._offset
end2 += self._offset

intersection = max(0, min(end1, end2) - max(start1, start2))
union = max(end1, end2) - min(start1, start2)
return intersection / union if union > 0 else 0
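For reference, a minimal standalone sketch of the interval-IoU computation this helper performs. The offset parameter and the example values below are illustrative only; the method above reads the shift from self._offset.

def interval_iou(interval1, interval2, offset=0.0):
    # Length of the overlap divided by the combined span of the two intervals.
    start1, end1 = interval1
    start2, end2 = interval2
    start2 += offset
    end2 += offset
    intersection = max(0.0, min(end1, end2) - max(start1, start2))
    span = max(end1, end2) - min(start1, start2)
    return intersection / span if span > 0 else 0.0

# Two 10-second segments overlapping by 5 seconds -> 5 / 15 ≈ 0.33
print(interval_iou((0.0, 10.0), (5.0, 15.0)))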

def _match_annotations(self, ds_annotations, gt_annotations):
"""
Match annotations between two datasets.
This method should compare annotations based on their start and end times.
"""

def _interval_iou(interval1, interval2):
start1, end1 = interval1
start2, end2 = interval2

start2 += self._offset
end2 += self._offset

intersection = max(0, min(end1, end2) - max(start1, start2))
union = max(end1, end2) - min(start1, start2)
return intersection / union if union > 0 else 0

job_start_time = self._offset - 0.1
job_end_time = job_start_time + self._job_duration + 0.1
job_start_time = self._offset
job_end_time = job_start_time + self._job_duration

# Filter gt_annotations to include only those within the job's time bounds
gt_annotations = [
@@ -2224,7 +2225,7 @@ def _interval_iou(interval1, interval2):
for ds_ann in ds_annotations:
gt_interval = (gt_ann["points"][0], gt_ann["points"][3])
ds_interval = (ds_ann["points"][0], ds_ann["points"][3])
iou = _interval_iou(gt_interval, ds_interval)
iou = self._interval_iou(gt_interval, ds_interval)

if gt_ann["label_id"] == ds_ann["label_id"]:
if iou >= self.settings.iou_threshold:
Expand All @@ -2250,8 +2251,8 @@ def _interval_iou(interval1, interval2):
# Check if the mismatch pair has acceptable WER and CER
gt_transcript = best_mismatch_pair[0]["transcript"]
ds_transcript = best_mismatch_pair[1]["transcript"]
wer = self.calculate_wer(gt_transcript, ds_transcript)
cer = self.calculate_cer(gt_transcript, ds_transcript)
wer = self._calculate_wer(gt_transcript, ds_transcript)
cer = self._calculate_cer(gt_transcript, ds_transcript)

if wer < self.settings.wer_threshold and cer < self.settings.cer_threshold:
mismatches.append(best_mismatch_pair)
Expand All @@ -2265,7 +2266,7 @@ def _interval_iou(interval1, interval2):

return [matches, mismatches, gt_unmatched, ds_unmatched, pairwise_distances]
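The method above additionally recovers near-miss pairs via WER/CER checks and records pairwise distances. A simplified, self-contained sketch of just the label-and-IoU matching step (hypothetical helper name, illustrative only):

from typing import Any, Dict, List, Tuple

def match_by_interval_iou(
    gt_annotations: List[Dict[str, Any]],
    ds_annotations: List[Dict[str, Any]],
    iou_threshold: float,
) -> Tuple[List[Tuple[dict, dict]], List[dict], List[dict]]:
    # Greedy pass: pair each GT segment with the first unmatched DS segment that has
    # the same label and an interval IoU at or above the threshold.
    matches: List[Tuple[dict, dict]] = []
    matched_gt, matched_ds = set(), set()
    for gi, gt_ann in enumerate(gt_annotations):
        gt_start, gt_end = gt_ann["points"][0], gt_ann["points"][3]
        for di, ds_ann in enumerate(ds_annotations):
            if di in matched_ds or gt_ann["label_id"] != ds_ann["label_id"]:
                continue
            ds_start, ds_end = ds_ann["points"][0], ds_ann["points"][3]
            intersection = max(0.0, min(gt_end, ds_end) - max(gt_start, ds_start))
            span = max(gt_end, ds_end) - min(gt_start, ds_start)
            iou = intersection / span if span > 0 else 0.0
            if iou >= iou_threshold:
                matches.append((gt_ann, ds_ann))
                matched_gt.add(gi)
                matched_ds.add(di)
                break
    gt_unmatched = [a for i, a in enumerate(gt_annotations) if i not in matched_gt]
    ds_unmatched = [a for i, a in enumerate(ds_annotations) if i not in matched_ds]
    return matches, gt_unmatched, ds_unmatched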

def match_attrs(self, ann_a, ann_b):
def _match_attrs(self, ann_a, ann_b):
a_attrs = ann_a["attributes"]
b_attrs = ann_b["attributes"]

Expand All @@ -2285,7 +2286,7 @@ def match_attrs(self, ann_a, ann_b):

return matches, a_unmatched, b_unmatched

def match_extra_parameters(self, gt_ann, ds_ann):
def _match_extra_parameters(self, gt_ann, ds_ann):
parameters = ["gender", "locale", "accent", "emotion", "age"]
matches = []
mismatches = []
Expand All @@ -2297,7 +2298,7 @@ def match_extra_parameters(self, gt_ann, ds_ann):

return matches, mismatches

def calculate_wer(self, gt_transcript, ds_transcript):
def _calculate_wer(self, gt_transcript, ds_transcript):
"""
Calculate the Word Error Rate (WER) between a ground truth transcript and an annotated transcript.
"""
@@ -2335,7 +2336,7 @@ def calculate_wer(self, gt_transcript, ds_transcript):
wer = d[len(gt_words)][len(ds_words)] / float(len(gt_words))
return wer

def calculate_cer(self, gt_transcript, ds_transcript):
def _calculate_cer(self, gt_transcript, ds_transcript):
"""
Calculate the Character Error Rate (CER) between a ground truth transcript and an annotated transcript.
"""
@@ -2379,7 +2380,7 @@ def _find_audio_gt_conflicts(self):
gt_frame_list = self._gt_data_provider.job_data._db_job.segment.frames

# Compare only if both the job's start and end frames fall within the GT job's frames
if not (start in gt_frame_list or end in gt_frame_list):
if not (start in gt_frame_list and end in gt_frame_list):
return # we need to compare only intersecting jobs

ds_annotations = self._ds_data_provider.job_annotation.data["shapes"]
Expand All @@ -2389,7 +2390,7 @@ def _find_audio_gt_conflicts(self):

def _process_job(self, ds_annotations, gt_annotations):
job_id = self._job_id
job_results = self.match_annotations(ds_annotations, gt_annotations)
job_results = self._match_annotations(ds_annotations, gt_annotations)
self._job_results.setdefault(job_id, {})

self._generate_job_annotation_conflicts(job_results, gt_annotations, ds_annotations)
Expand All @@ -2401,6 +2402,7 @@ def _generate_job_annotation_conflicts(
job_id = self._job_id
word_error_rate = 0
character_error_rate = 0
iou_sum = 0

matches, mismatches, gt_unmatched, ds_unmatched, _ = job_results

@@ -2434,10 +2436,14 @@ def _generate_job_annotation_conflicts(
for gt_ann, ds_ann in matches:
gt_transcript = gt_ann["transcript"]
ds_transcript = ds_ann["transcript"]
wer = self.calculate_wer(gt_transcript, ds_transcript)
cer = self.calculate_cer(gt_transcript, ds_transcript)
gt_interval = (gt_ann["points"][0], gt_ann["points"][3])
ds_interval = (ds_ann["points"][0], ds_ann["points"][3])
iou = self._interval_iou(gt_interval, ds_interval)
wer = self._calculate_wer(gt_transcript, ds_transcript)
cer = self._calculate_cer(gt_transcript, ds_transcript)
word_error_rate += wer
character_error_rate += cer
iou_sum += iou
if wer > self.settings.wer_threshold or cer > self.settings.cer_threshold:
conflicts.append(
AnnotationConflict(
Expand All @@ -2454,7 +2460,7 @@ def _generate_job_annotation_conflicts(

if self.settings.compare_attributes:
for gt_ann, ds_ann in matches:
attribute_results = self.match_attrs(gt_ann, ds_ann)
attribute_results = self._match_attrs(gt_ann, ds_ann)
if any(attribute_results[1:]):
conflicts.append(
AnnotationConflict(
Expand All @@ -2469,7 +2475,7 @@ def _generate_job_annotation_conflicts(

if self.settings.compare_extra_parameters:
for gt_ann, ds_ann in matches:
extra_parameter_results = self.match_extra_parameters(gt_ann, ds_ann)
extra_parameter_results = self._match_extra_parameters(gt_ann, ds_ann)
if any(extra_parameter_results[1:]):
conflicts.append(
AnnotationConflict(
@@ -2517,7 +2523,7 @@ def _generate_job_annotation_conflicts(
total_count=total_shapes_count,
ds_count=ds_shapes_count,
gt_count=gt_shapes_count,
mean_iou=0.7, # need to fix
mean_iou=iou_sum / len(matches) if len(matches) > 0 else 0.0,
),
label=ComparisonReportAnnotationLabelSummary(
valid_count=valid_labels_count,
@@ -2620,9 +2626,9 @@ def generate_audio_report(self) -> ComparisonReport:
total_count=0,
),
)
mean_ious = []
confusion_matrix_labels, confusion_matrix, _ = self._make_zero_confusion_matrix()

# Fix: no need for this loop, since this function is only called for a single job
for job_id, job_result in self._job_results.items():
intersection_frames.append(job_id)
conflicts += job_result.conflicts
Expand All @@ -2632,7 +2638,6 @@ def generate_audio_report(self) -> ComparisonReport:
annotation_components = deepcopy(job_result.annotation_components)
else:
annotation_components.accumulate(job_result.annotation_components)
mean_ious.append(job_result.annotation_components.shape.mean_iou)

job_result = self._job_results.get(self._job_id, None)
if job_result:
@@ -2666,7 +2671,7 @@ def generate_audio_report(self) -> ComparisonReport:
total_count=annotation_components.shape.total_count,
ds_count=annotation_components.shape.ds_count,
gt_count=annotation_components.shape.gt_count,
mean_iou=np.mean(mean_ious),
mean_iou=annotation_components.shape.mean_iou,
),
label=ComparisonReportAnnotationLabelSummary(
valid_count=annotation_components.label.valid_count,
@@ -2856,7 +2861,7 @@ def _compute_reports(self, task_id: int) -> int:
start = job_data_provider.job_data.start
end = job_data_provider.job_data.stop - 1
gt_frame_list = list(gt_job_frames)
if not (start in gt_frame_list or end in gt_frame_list):
if not (start in gt_frame_list and end in gt_frame_list):
offset = 0
ind -= 1

