Skip to content

Commit

Permalink
IDWA#114 Create Metrics Functions to Measure Accuracy of our OCR appr…
Browse files Browse the repository at this point in the history
…oaches (#125)

* added new metrics analysis functions

* added examples of files generated

* made edits to json files added more metrics and cleaned metrics files

* linting

* added precision and recall

* Delete OCR/tests/assets/extracted_elements.json

* added new ocr extraction file

* edited order

* edited metrics file

* added ground_truth examples

* edited metrics to reflect new metrics

* split metrics into class

* added metrics tests

* updated lock file

* match to main

* Delete OCR/tests/assets/ground_truth_ltbi_lab.json

* Delete OCR/tests/assets/ground_truth_mumps.json

* Delete OCR/tests/assets/ground_truth_syp.json

* Delete OCR/tests/assets/ground_truth_pertusis.json

* Delete OCR/tests/assets/ocr_elements.json

---------

Co-authored-by: Arindam Kulshi <[email protected]>
  • Loading branch information
arinkulshi-skylight and arinkulshi authored Jul 10, 2024
1 parent 60288d7 commit 579eac4
Show file tree
Hide file tree
Showing 7 changed files with 567 additions and 148 deletions.
17 changes: 17 additions & 0 deletions OCR/ocr/metrics_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Command-line entry point: compare an OCR extraction JSON file against a
ground-truth JSON file, print per-key and overall metrics, and save them to CSV."""

from services.metrics_analysis import OCRMetrics
import os


def main():
    """Resolve the test-asset paths relative to this script, compute OCR metrics,
    print each per-key result plus the overall summary, and write a CSV report."""
    # Paths are resolved relative to this file so the script works from any CWD.
    current_script_dir = os.path.dirname(os.path.abspath(__file__))
    ground_truth_json_path = os.path.join(current_script_dir, "../tests/assets/ltbi_legacy.json")
    ocr_json_path = os.path.join(current_script_dir, "../tests/assets/ltbi_legacy_ocr.json")

    ocr_metrics = OCRMetrics(ocr_json_path, ground_truth_json_path)
    metrics = ocr_metrics.calculate_metrics()
    for m in metrics:
        print(m)
    overall_metrics = ocr_metrics.total_metrics(metrics)
    print("Overall Metrics:", overall_metrics)
    OCRMetrics.save_metrics_to_csv(metrics, "new.csv")


if __name__ == "__main__":
    # Guard prevents the analysis (file I/O, printing, CSV writing) from
    # running as a side effect of importing this module.
    main()
2 changes: 2 additions & 0 deletions OCR/ocr/services/image_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def __init__(
self.labels = labels

if self.debug is True:
self.debug_folder = "debug_segments"
os.makedirs(self.debug_folder, exist_ok=True)
print(f"raw_image shape: {self.raw_image.shape}")
print(f"segmentation_template shape: {self.segmentation_template.shape}")

Expand Down
124 changes: 124 additions & 0 deletions OCR/ocr/services/metrics_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import json

import csv
import Levenshtein


class OCRMetrics:
    """
    Calculates and aggregates string-distance metrics (raw length difference,
    Hamming, Levenshtein) comparing OCR-extracted key/value pairs against
    ground truth.
    """

    def __init__(
        self, ocr_json_path=None, ground_truth_json_path=None, ocr_json=None, ground_truth_json=None, testMode=False
    ):
        """
        Parameters:
            ocr_json_path (str): Path to a JSON file of OCR-extracted key/value pairs.
            ground_truth_json_path (str): Path to a JSON file of ground-truth key/value pairs.
            ocr_json (list[dict]): Pre-loaded OCR data; used only when testMode is True.
            ground_truth_json (list[dict]): Pre-loaded ground-truth data; used only when testMode is True.
            testMode (bool): When True, skip file I/O and use the pre-loaded JSON arguments.
        """
        if testMode:
            self.ocr_json = ocr_json
            self.ground_truth_json = ground_truth_json
        else:
            self.ocr_json = self.load_json_file(ocr_json_path)
            self.ground_truth_json = self.load_json_file(ground_truth_json_path)

    def load_json_file(self, file_path):
        """Load and return parsed JSON from file_path; returns None when no path is given."""
        if file_path:
            with open(file_path, "r") as file:
                return json.load(file)
        # Explicit None (was an implicit fall-through return in the original).
        return None

    @staticmethod
    def normalize(text):
        """Lowercase, strip, and collapse internal whitespace; None becomes ""."""
        if text is None:
            return ""
        return " ".join(text.strip().lower().split())

    @staticmethod
    def raw_distance(ocr_text, ground_truth):
        """Signed length difference (ground truth minus OCR); positive means OCR output is shorter."""
        return len(ground_truth) - len(ocr_text)

    @staticmethod
    def hamming_distance(ocr_text, ground_truth):
        """Hamming distance between two equal-length strings.

        Raises:
            ValueError: if the strings differ in length (Hamming is undefined then).
        """
        if len(ocr_text) != len(ground_truth):
            raise ValueError("Strings must be of the same length to calculate Hamming distance.")
        return Levenshtein.hamming(ocr_text, ground_truth)

    @staticmethod
    def levenshtein_distance(ocr_text, ground_truth):
        """Edit distance between the OCR text and the ground truth."""
        return Levenshtein.distance(ocr_text, ground_truth)

    def extract_values_from_json(self, json_data):
        """
        Convert a list of {"key": ..., "value": ...} dicts into a normalized
        {key: value} mapping.

        Raises:
            ValueError: if any item is not a dict containing "key" and "value".
        """
        extracted_values = {}
        for item in json_data:
            if not (isinstance(item, dict) and "key" in item and "value" in item):
                raise ValueError("Invalid JSON format")
            extracted_values[self.normalize(item["key"])] = self.normalize(item["value"])
        return extracted_values

    def calculate_metrics(self):
        """
        Compare OCR values against ground truth per key.

        Returns:
            list[dict]: one dict per ground-truth key with the OCR text and the
            three distance metrics. When string lengths differ, the
            "hamming_distance" entry holds the error message string instead of
            a number; keys absent from the OCR output are compared as "".
        """
        ocr_values = self.extract_values_from_json(self.ocr_json)
        ground_truth_values = self.extract_values_from_json(self.ground_truth_json)

        metrics = []
        for key, ground_truth in ground_truth_values.items():
            ocr_text = ocr_values.get(key, "")  # missing OCR key counts as empty output
            try:
                ham_dist = self.hamming_distance(ocr_text, ground_truth)
            except ValueError as e:
                ham_dist = str(e)  # length mismatch: record the reason instead of a number
            metrics.append(
                {
                    "key": key,
                    "ocr_text": ocr_text,
                    "ground_truth": ground_truth,
                    "raw_distance": self.raw_distance(ocr_text, ground_truth),
                    "hamming_distance": ham_dist,
                    "levenshtein_distance": self.levenshtein_distance(ocr_text, ground_truth),
                }
            )
        return metrics

    @staticmethod
    def total_metrics(metrics):
        """
        Aggregate per-key metrics into totals plus an accuracy percentage
        (1 - total Levenshtein distance / total ground-truth length).

        Parameters:
            metrics (list[dict]): output of calculate_metrics().

        Returns:
            dict: total distances and "levenshtein_accuracy" formatted as "NN.NN%".
        """
        total_raw_distance = sum(item["raw_distance"] for item in metrics if isinstance(item["raw_distance"], int))
        total_levenshtein_distance = sum(
            item["levenshtein_distance"] for item in metrics if isinstance(item["levenshtein_distance"], int)
        )
        # Non-numeric entries (length-mismatch messages from calculate_metrics)
        # are filtered out by the isinstance check, so this sum cannot raise;
        # the original try/except ValueError around it was unreachable dead code.
        total_hamming_distance = sum(
            item["hamming_distance"] for item in metrics if isinstance(item["hamming_distance"], int)
        )

        ground_truth_length = sum(len(item["ground_truth"]) for item in metrics)
        # Guard against division by zero when there is no ground-truth text.
        normalized_levenshtein_distance = (
            total_levenshtein_distance / ground_truth_length if ground_truth_length else 0
        )
        accuracy = (1 - normalized_levenshtein_distance) * 100

        return {
            "total_raw_distance": total_raw_distance,
            "total_hamming_distance": total_hamming_distance,
            "total_levenshtein_distance": total_levenshtein_distance,
            "levenshtein_accuracy": f"{accuracy:.2f}%",
        }

    @staticmethod
    def save_metrics_to_csv(metrics, file_path):
        """Write the per-key metrics to file_path as CSV; a no-op for an empty list."""
        if not metrics:
            return  # fix: the original raised IndexError on metrics[0] for an empty list
        with open(file_path, "w", newline="") as output_file:
            dict_writer = csv.DictWriter(output_file, fieldnames=metrics[0].keys())
            dict_writer.writeheader()
            dict_writer.writerows(metrics)
Loading

0 comments on commit 579eac4

Please sign in to comment.