diff --git a/discrete-binarization/README.md b/discrete-binarization/README.md
new file mode 100644
index 0000000000..ca73be0767
--- /dev/null
+++ b/discrete-binarization/README.md
@@ -0,0 +1,10 @@
+## Download the dataset from Kaggle
+
+```shell
+kaggle datasets download -d torkiasalem/text-detection-icdar2015
+unzip -qq text-detection-icdar2015.zip
+```
+
+## Reference
+
+- https://github.com/MhLiao/DB
diff --git a/discrete-binarization/dataset.py b/discrete-binarization/dataset.py
new file mode 100644
index 0000000000..32658607e5
--- /dev/null
+++ b/discrete-binarization/dataset.py
@@ -0,0 +1,89 @@
+import json
+from pathlib import Path
+
+import cv2
+import keras
+import numpy as np
+from label_generator import generate_text_probability_map
+from label_generator import generate_threshold_label
+
+
+def read_label_file(file_path):
+    with open(file_path, "r") as f:
+        data = f.readlines()
+    return data
+
+
+def decode_image(image_path):
+    # cv2.imread returns None instead of raising for unreadable files,
+    # so check explicitly rather than relying on a bare except
+    image = cv2.imread(image_path)
+    if image is None:
+        return None
+    if image.shape[-1] == 1:
+        image = np.dstack([image, image, image])
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    return image
+
+
+def pad_polygons(polygons):
+    # Pad every polygon to the same vertex count by repeating its last
+    # point, so the batch fits in one rectangular array
+    max_num_points = 0
+    for polygon in polygons:
+        if len(polygon) > max_num_points:
+            max_num_points = len(polygon)
+    padded_polygons = list()
+    for polygon in polygons:
+        polygon = polygon + [polygon[-1]] * (max_num_points - len(polygon))
+        padded_polygons.append(polygon)
+    return padded_polygons
+
+
+def decode_label(label):
+    label = json.loads(label)
+    polygons, texts, ignore_flags = list(), list(), list()
+    for info in label:
+        text = info["transcription"]
+        # "*" and "###" mark unreadable text that should not train the model
+        if text in ["*", "###"]:
+            ignore_flags.append(True)
+        else:
+            ignore_flags.append(False)
+        polygon = info["points"]
+        polygons.append(polygon)
+        texts.append(text)
+    polygons = pad_polygons(polygons)
+    return polygons, texts, ignore_flags
+
+
+class ICDARPyDataset(keras.utils.PyDataset):
+    def __init__(self, image_dir, file_path, **kwargs):
+        super().__init__(**kwargs)
+        self.image_dir = image_dir
+        self.file_path = file_path
+        self.data_lines = read_label_file(file_path)
+
+    def __len__(self):
+        return len(self.data_lines)
+
+    def __getitem__(self, idx):
+        # 1. Grab a single line from the data_lines list
+        line = self.data_lines[idx]
+        # 2. Split the line into image_path and label
+        image_path, label = line.strip().split("\t")
+        image_path = str(Path(self.image_dir) / Path(image_path).name)
+        # 3. Decode the image
+        image = decode_image(image_path)
+        if image is None:
+            return dict()
+        # 4. Decode the labels
+        polygons, texts, ignore_flags = decode_label(label)
+        if len(polygons) == 0:
+            return dict()
+        data = {
+            "image": image,
+            "polygons": np.array(polygons),
+            "texts": np.array(texts),
+            "ignore_flags": np.array(ignore_flags),
+        }
+        # 5. Generate the probability and threshold targets in place
+        generate_text_probability_map(data)
+        generate_threshold_label(data)
+        return data
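+
+
+# --- Hedged usage sketch (not part of the original training pipeline) ---
+# A minimal smoke test for ICDARPyDataset, assuming the Kaggle archive from
+# the README has been unpacked into `text_localization/` with the directory
+# and label-file names used in train.py. Run `python dataset.py` to inspect
+# the shapes of one decoded sample and its generated target maps.
+if __name__ == "__main__":
+    dataset = ICDARPyDataset(
+        image_dir="text_localization/icdar_c4_train_imgs",
+        file_path="text_localization/train_icdar2015_label.txt",
+    )
+    sample = dataset[0]
+    for key, value in sample.items():
+        # e.g. image (720, 1280, 3), shrink_map (720, 1280), ...
+        print(key, getattr(value, "shape", value))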
diff --git a/discrete-binarization/label_generator.py b/discrete-binarization/label_generator.py
new file mode 100644
index 0000000000..2b05602150
--- /dev/null
+++ b/discrete-binarization/label_generator.py
@@ -0,0 +1,279 @@
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+
+######################################################################
+# Label Generation for Threshold Maps
+######################################################################
+def calculate_distance(grid_x, grid_y, segment_start, segment_end):
+    """
+    Calculate the perpendicular distance from each point in a grid to a line segment.
+
+    Args:
+    - grid_x (np.ndarray): The x-coordinates of the grid.
+    - grid_y (np.ndarray): The y-coordinates of the grid.
+    - segment_start (tuple): The starting point of the line segment.
+    - segment_end (tuple): The ending point of the line segment.
+
+    Returns:
+    - np.ndarray: An array of distances.
+    """
+    # Squared distances from every grid point to the two segment endpoints
+    distance_to_start_sq = np.square(grid_x - segment_start[0]) + np.square(
+        grid_y - segment_start[1]
+    )
+    distance_to_end_sq = np.square(grid_x - segment_end[0]) + np.square(
+        grid_y - segment_end[1]
+    )
+
+    # Squared length of the segment itself
+    segment_length_sq = np.square(
+        segment_start[0] - segment_end[0]
+    ) + np.square(segment_start[1] - segment_end[1])
+
+    # Cosine proxy for the angle the segment subtends at each grid point,
+    # via the law of cosines
+    cosine_angle = (
+        segment_length_sq - distance_to_start_sq - distance_to_end_sq
+    ) / (2 * np.sqrt(distance_to_start_sq * distance_to_end_sq) + 1e-6)
+
+    # Squared sine of the same angle
+    sine_square = 1 - np.square(cosine_angle)
+    sine_square = np.nan_to_num(sine_square)
+
+    # Perpendicular distance to the segment's supporting line
+    perpendicular_distance = np.sqrt(
+        distance_to_start_sq
+        * distance_to_end_sq
+        * sine_square
+        / segment_length_sq
+    )
+    # Where the cosine proxy is negative, fall back to the distance to the
+    # nearest endpoint
+    perpendicular_distance[cosine_angle < 0] = np.sqrt(
+        np.fmin(distance_to_start_sq, distance_to_end_sq)
+    )[cosine_angle < 0]
+
+    return perpendicular_distance
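+
+
+# Hedged sanity check for calculate_distance (illustrative only, not part of
+# the label pipeline). Writing d1, d2 for the endpoint distances, L for the
+# segment length, and theta for the angle the segment subtends at the grid
+# point, twice the triangle area is d1 * d2 * sin(theta) = L * h, so the
+# point-to-line distance is h = d1 * d2 * sin(theta) / L, which is exactly
+# what the sqrt above evaluates. For the unit segment (0, 0)-(1, 0), the
+# point (0.5, 0.1) should come out at distance 0.1:
+#
+#     gx, gy = np.array([[0.5]]), np.array([[0.1]])
+#     d = calculate_distance(gx, gy, (0.0, 0.0), (1.0, 0.0))
+#     assert np.isclose(d[0, 0], 0.1)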
+ """ + polygon = np.array(polygon_points) + polygon_shape = Polygon(polygon) + + # Exit if the polygon area is not positive + if polygon_shape.area <= 0: + return + + # Calculating the distance to shrink the polygon by + shrink_distance = ( + polygon_shape.area + * (1 - np.power(shrink_ratio, 2)) + / polygon_shape.length + ) + + # Creating the shrunk polygon + subject_polygon = [tuple(l) for l in polygon] + polygon_padder = pyclipper.PyclipperOffset() + polygon_padder.AddPath( + subject_polygon, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON + ) + shrunk_polygon = np.array(polygon_padder.Execute(shrink_distance)[0]) + + # Filling the polygon in the mask + cv2.fillPoly(mask, [shrunk_polygon.astype(np.int32)], 1.0) + + # Calculating bounding box for the shrunk polygon + xmin, xmax = shrunk_polygon[:, 0].min(), shrunk_polygon[:, 0].max() + ymin, ymax = shrunk_polygon[:, 1].min(), shrunk_polygon[:, 1].max() + width, height = xmax - xmin + 1, ymax - ymin + 1 + + # Adjusting polygon points relative to the bounding box + adjusted_polygon = polygon.copy() + adjusted_polygon[:, 0] -= xmin + adjusted_polygon[:, 1] -= ymin + + # Creating grids for distance calculation + grid_x = np.broadcast_to( + np.linspace(0, width - 1, num=width).reshape(1, width), (height, width) + ) + grid_y = np.broadcast_to( + np.linspace(0, height - 1, num=height).reshape(height, 1), + (height, width), + ) + + # Initializing the distance map + distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32) + + # Calculating the distance to each edge of the polygon + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = calculate_distance( + grid_x, grid_y, adjusted_polygon[i], adjusted_polygon[j] + ) + distance_map[i] = np.clip(absolute_distance / shrink_distance, 0, 1) + + # Taking the minimum distance to any edge + distance_map = distance_map.min(axis=0) + + # Validating and adjusting the bounding box coordinates + xmin_valid, xmax_valid = min(max(0, xmin), canvas.shape[1] - 1), min( + max(0, xmax), canvas.shape[1] - 1 + ) + ymin_valid, ymax_valid = min(max(0, ymin), canvas.shape[0] - 1), min( + max(0, ymax), canvas.shape[0] - 1 + ) + + # Applying the distance map to the canvas + canvas_slice = canvas[ + ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1 + ] + distance_map_slice = distance_map[ + ymin_valid - ymin : ymax_valid - ymax + height, + xmin_valid - xmin : xmax_valid - xmax + width, + ] + canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax( + 1 - distance_map_slice, canvas_slice + ) + + +def generate_threshold_label( + data, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7 +): + """ + Generate a threshold label for an image based on given polygons and ignore tags. + + Args: + - data (dict): A dictionary containing 'image', 'polygons', and 'ignore_tags'. + - shrink_ratio (float): Ratio to shrink polygons by. + - thresh_min (float): Minimum threshold value. + - thresh_max (float): Maximum threshold value. 
+ """ + image = data["image"] + polygons = data["polygons"] + ignore_flags = data["ignore_flags"] + + # Initializing canvas and mask + canvas = np.zeros(image.shape[:2], dtype=np.float32) + mask = np.zeros(image.shape[:2], dtype=np.float32) + + # Drawing each polygon on the canvas + for polygon, ignore_flag in zip(polygons, ignore_flags): + if not ignore_flag: + draw_border_map_on_canvas( + polygon, canvas, mask=mask, shrink_ratio=shrink_ratio + ) + + # Adjusting the canvas based on thresholds + canvas = canvas * (thresh_max - thresh_min) + thresh_min + + # Updating the data dictionary + data["threshold_map"] = canvas + data["threshold_mask"] = mask + + +###################################################################### +# Label Generation for Probability and Binary Maps +###################################################################### +def adjust_polygons_within_image_bounds( + polygons, ignore_flags, image_height, image_width +): + """ + Adjusts polygons to ensure they fit within the image bounds and calculates + their area to update ignore flags for invalid polygons. + + Parameters: + - polygons (np.ndarray): Array of polygons coordinates. + - ignore_flags (list): List of boolean flags to ignore certain polygons. + - image_height (int): Height of the image. + - image_width (int): Width of the image. + + Returns: + - Tuple: (Adjusted polygons, Updated ignore flags) + """ + for i, polygon in enumerate(polygons): + # Clip polygon coordinates to be within image dimensions + polygons[i] = np.clip( + polygon, [0, 0], [image_width - 1, image_height - 1] + ) + + # If polygon area is too small, mark it as ignored + if Polygon(polygon).area < 1: + ignore_flags[i] = True + polygons[i] = polygons[i][::-1] # Reverse polygon coordinates + + return polygons, ignore_flags + + +def shrink_polygon(polygon, shrink_ratio): + """ + Shrinks a given polygon according to the specified shrink ratio. + + Parameters: + - polygon (Polygon): A Shapely Polygon object. + - shrink_ratio (float): Ratio to shrink the polygon. + + Returns: + - Polygon: The shrunk polygon. + """ + if polygon.area == 0: + return Polygon() + + # Calculate shrinking distance based on area and perimeter + shrink_distance = polygon.area * (1 - shrink_ratio**2) / polygon.length + return polygon.buffer(-shrink_distance, join_style=2) + + +def generate_text_probability_map(data, min_text_size=8, shrink_ratio=0.4): + """ + Generates a probability map for text detection in an image by processing + the provided polygons. + + Parameters: + - data (dict): Contains the image, polygons, and ignore tags. + - min_text_size (int): Minimum size of text to be detected. + - shrink_ratio (float): Ratio for shrinking the polygons. + + Modifies: + - data (dict): Adds the shrink_map and shrink_mask to the data. 
+ """ + image = data["image"] + polygons = data["polygons"] + ignore_flags = data["ignore_flags"] + image_height, image_width = image.shape[:2] + + polygons, ignore_flags = adjust_polygons_within_image_bounds( + polygons, ignore_flags, image_height, image_width + ) + text_region_map = np.zeros((image_height, image_width), dtype=np.float32) + mask_map = np.ones((image_height, image_width), dtype=np.float32) + + for i, polygon in enumerate(polygons): + # Check for ignored polygons or those too small to consider + if ( + ignore_flags[i] + or min(polygon[:, 1].ptp(), polygon[:, 0].ptp()) < min_text_size + ): + cv2.fillPoly(mask_map, [polygon.astype(np.int32)], 0) + continue + + shrunk_polygon = shrink_polygon(Polygon(polygon), shrink_ratio) + if shrunk_polygon.is_empty: + cv2.fillPoly(mask_map, [polygon.astype(np.int32)], 0) + else: + cv2.fillPoly( + text_region_map, + [np.array(shrunk_polygon.exterior.coords).astype(np.int32)], + 1, + ) + + data["shrink_map"] = text_region_map + data["shrink_mask"] = mask_map diff --git a/discrete-binarization/losses.py b/discrete-binarization/losses.py new file mode 100644 index 0000000000..7d3f894af3 --- /dev/null +++ b/discrete-binarization/losses.py @@ -0,0 +1,103 @@ +import keras +from keras import ops + + +class DiceLoss(keras.losses.Loss): + def __init__(self, eps=1e-6, **kwargs): + super().__init__(**kwargs) + self.eps = eps + + def call(self, y_true, y_pred, mask, weights=None): + if weights is not None: + mask = weights * mask + intersection = ops.sum((y_pred * y_true * mask)) + union = ops.sum((y_pred * mask)) + ops.sum(y_true * mask) + self.eps + loss = 1 - 2.0 * intersection / union + return loss + + +class MaskL1Loss(keras.losses.Loss): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, y_true, y_pred, mask): + mask_sum = ops.sum(mask) + loss = ops.cond( + mask_sum == 0, + lambda: mask_sum, + lambda: ops.sum(ops.absolute(y_pred - y_true) * mask) / mask_sum, + ) + return loss + + +class BalanceCrossEntropyLoss(keras.losses.Loss): + def __init__(self, negative_ratio=3.0, eps=1e-6, **kwargs): + super().__init__(**kwargs) + self.negative_ratio = negative_ratio + self.eps = eps + + def forward(self, y_true, y_pred, mask, return_origin=False): + positive = ops.cast(y_true * mask, "uint8") + negative = ops.cast(((1 - y_true) * mask), "uint8") + positive_count = int(ops.sum(ops.cast(positive, "float32"))) + negative_count = ops.min( + int(ops.sum(ops.cast(negative, "float32"))), + int(positive_count * self.negative_ratio), + ) + loss = keras.losses.BinaryCrossentropy( + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction=None, + )(y_true=y_true, y_pred=y_pred) + positive_loss = loss * ops.cast(positive, "float32") + negative_loss = loss * ops.cast(negative, "float32") + negative_loss, _ = ops.topk( + ops.reshape(negative_loss, (-1)), negative_count + ) + + balance_loss = (ops.sum(positive_loss) + ops.sum(negative_loss)) / ( + positive_count + negative_count + self.eps + ) + + if return_origin: + return balance_loss, loss + return balance_loss + + +class DBLoss(keras.losses.Loss): + def __init__(self, eps=1e-6, l1_scale=10, bce_scale=5, **kwargs): + super().__init__() + self.dice_loss = DiceLoss(eps=eps) + self.l1_loss = MaskL1Loss() + self.bce_loss = BalanceCrossEntropyLoss() + + self.l1_scale = l1_scale + self.bce_scale = bce_scale + + def call(self, y_true, y_pred, mask): + p_map_pred, t_map_pred, b_map_pred = ops.split(y_pred, 3, axis=-1) + shrink_map, thresh_map = y_true + shrink_mask, thresh_mask = mask + + 
+
+
+class DBLoss(keras.losses.Loss):
+    def __init__(self, eps=1e-6, l1_scale=10, bce_scale=5, **kwargs):
+        super().__init__(**kwargs)
+        self.dice_loss = DiceLoss(eps=eps)
+        self.l1_loss = MaskL1Loss()
+        self.bce_loss = BalanceCrossEntropyLoss()
+
+        self.eps = eps
+        self.l1_scale = l1_scale
+        self.bce_scale = bce_scale
+
+    def call(self, y_true, y_pred, mask):
+        # The model output stacks probability, threshold, and binary maps
+        p_map_pred, t_map_pred, b_map_pred = ops.split(y_pred, 3, axis=-1)
+        shrink_map, thresh_map = y_true
+        shrink_mask, thresh_mask = mask
+
+        # Balanced BCE on the probability map; also returns the raw
+        # per-pixel loss map used to weight the dice loss below
+        bce_loss, bce_map = self.bce_loss.call(
+            y_true=shrink_map,
+            y_pred=p_map_pred,
+            mask=shrink_mask,
+            return_origin=True,
+        )
+        # Masked L1 on the threshold map
+        l1_loss = self.l1_loss.call(
+            y_true=thresh_map,
+            y_pred=t_map_pred,
+            mask=thresh_mask,
+        )
+        # Min-max normalize the BCE map into [0, 1] before using it as
+        # per-pixel dice weights
+        bce_map = (bce_map - ops.min(bce_map)) / (
+            ops.max(bce_map) - ops.min(bce_map) + self.eps
+        )
+        # Weighted dice loss on the binary map
+        dice_loss = self.dice_loss.call(
+            y_true=shrink_map,
+            y_pred=b_map_pred,
+            mask=shrink_mask,
+            weights=bce_map + 1,
+        )
+        loss = dice_loss + self.l1_scale * l1_loss + self.bce_scale * bce_loss
+        return loss
diff --git a/discrete-binarization/model.py b/discrete-binarization/model.py
new file mode 100644
index 0000000000..8ba1cfb417
--- /dev/null
+++ b/discrete-binarization/model.py
@@ -0,0 +1,119 @@
+import math
+
+import keras
+from keras import layers
+from keras import ops
+
+import keras_cv
+
+
+def get_backbone_model(input_shape=(640, 640, 3)):
+    resnet = keras_cv.models.ResNet50Backbone.from_preset(
+        "resnet50_imagenet", input_shape=input_shape, load_weights=False
+    )
+    levels = ["P2", "P3", "P4", "P5"]
+    layer_names = [resnet.pyramid_level_inputs[level] for level in levels]
+    items = zip(levels, layer_names)
+    outputs = {key: resnet.get_layer(name).output for key, name in items}
+    backbone = keras.Model(resnet.inputs, outputs=outputs)
+    return backbone
+
+
+def FPNModel(out_channels, **kwargs):
+    def apply(inputs):
+        # Feature maps at strides 4, 8, 16, and 32
+        c2 = inputs["P2"]
+        c3 = inputs["P3"]
+        c4 = inputs["P4"]
+        c5 = inputs["P5"]
+        # Lateral 1x1 convolutions onto a common channel width
+        in2 = layers.Conv2D(out_channels, kernel_size=1, use_bias=False)(c2)
+        in3 = layers.Conv2D(out_channels, kernel_size=1, use_bias=False)(c3)
+        in4 = layers.Conv2D(out_channels, kernel_size=1, use_bias=False)(c4)
+        in5 = layers.Conv2D(out_channels, kernel_size=1, use_bias=False)(c5)
+        # Top-down pathway: upsample and add
+        out4 = layers.Add()([layers.UpSampling2D()(in5), in4])
+        out3 = layers.Add()([layers.UpSampling2D()(out4), in3])
+        out2 = layers.Add()([layers.UpSampling2D()(out3), in2])
+        p5 = layers.Conv2D(
+            out_channels // 4, kernel_size=3, padding="same", use_bias=False
+        )(in5)
+        p4 = layers.Conv2D(
+            out_channels // 4, kernel_size=3, padding="same", use_bias=False
+        )(out4)
+        p3 = layers.Conv2D(
+            out_channels // 4, kernel_size=3, padding="same", use_bias=False
+        )(out3)
+        p2 = layers.Conv2D(
+            out_channels // 4, kernel_size=3, padding="same", use_bias=False
+        )(out2)
+        # Bring every level up to the stride-4 resolution before fusing
+        p5 = layers.UpSampling2D((8, 8))(p5)
+        p4 = layers.UpSampling2D((4, 4))(p4)
+        p3 = layers.UpSampling2D((2, 2))(p3)
+
+        fused = layers.Concatenate(axis=-1)([p5, p4, p3, p2])
+        return fused
+
+    return apply
+
+
+def Head(in_channels, kernel_list=[3, 2, 2], **kwargs):
+    def apply(inputs):
+        x = layers.Conv2D(
+            in_channels // 4,
+            kernel_size=kernel_list[0],
+            padding="same",
+            use_bias=False,
+        )(inputs)
+        x = layers.BatchNormalization(
+            beta_initializer=keras.initializers.Constant(1e-4),
+            gamma_initializer=keras.initializers.Constant(1.0),
+        )(x)
+        x = layers.ReLU()(x)
+        x = layers.Conv2DTranspose(
+            in_channels // 4,
+            kernel_size=kernel_list[1],
+            strides=2,
+            padding="valid",
+            bias_initializer=keras.initializers.RandomUniform(
+                minval=-1.0 / math.sqrt(in_channels // 4 * 1.0),
+                maxval=1.0 / math.sqrt(in_channels // 4 * 1.0),
+            ),
+        )(x)
+        x = layers.BatchNormalization(
+            beta_initializer=keras.initializers.Constant(1e-4),
+            gamma_initializer=keras.initializers.Constant(1.0),
+        )(x)
+        x = layers.ReLU()(x)
+        x = layers.Conv2DTranspose(
+            1,
+            kernel_size=kernel_list[2],
+            strides=2,
+            padding="valid",
+            activation="sigmoid",
+            bias_initializer=keras.initializers.RandomUniform(
+                minval=-1.0 / math.sqrt(in_channels // 4 * 1.0),
+                maxval=1.0 / math.sqrt(in_channels // 4 * 1.0),
+            ),
+        )(x)
+        return x
+
+    return apply
+
+
+def step_function(x, y, k=50):
+    # Differentiable binarization: a steep sigmoid of the gap between the
+    # probability map x and the threshold map y
+    return 1.0 / (1.0 + ops.exp(-k * (x - y)))
+
+
+def DBHead(in_channels, k=50, **kwargs):
+    def apply(inputs, training):
+        probability_maps = Head(in_channels, **kwargs)(inputs)
+        if not training:
+            return probability_maps
+
+        threshold_maps = Head(in_channels, **kwargs)(inputs)
+        binary_maps = step_function(probability_maps, threshold_maps, k=k)
+        y = layers.Concatenate(axis=-1)(
+            [probability_maps, threshold_maps, binary_maps]
+        )
+        return y
+
+    return apply
diff --git a/discrete-binarization/train.py b/discrete-binarization/train.py
new file mode 100644
index 0000000000..b095396bec
--- /dev/null
+++ b/discrete-binarization/train.py
@@ -0,0 +1,42 @@
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+from pathlib import Path
+
+import keras
+import numpy as np
+from dataset import ICDARPyDataset
+from losses import DBLoss
+from model import DBHead
+from model import FPNModel
+from model import get_backbone_model
+
+if __name__ == "__main__":
+    base_data_dir = Path("text_localization")
+    train_images_dir = base_data_dir / "icdar_c4_train_imgs"
+    test_images_dir = base_data_dir / "ch4_test_images"
+    train_labels_path = base_data_dir / "train_icdar2015_label.txt"
+    test_labels_path = base_data_dir / "test_icdar2015_label.txt"
+
+    dataset = ICDARPyDataset(train_images_dir, train_labels_path)
+
+    input_shape = (None, None, 3)
+    inputs = keras.Input(shape=input_shape)
+    x = get_backbone_model(input_shape)(inputs)
+    x = FPNModel(out_channels=256)(x)
+    outputs = DBHead(in_channels=256)(x, training=True)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+
+    db_loss = DBLoss()
+
+    # Forward propagate a single sample. ICDAR images are 1280 x 720, and
+    # 720 is not a multiple of 32, so pad first; otherwise the FPN's
+    # upsample-and-add stages do not line up. Zero-padding the masks also
+    # excludes the padded pixels from the loss.
+    def pad_to_multiple(array, multiple=32):
+        pad_h = -array.shape[0] % multiple
+        pad_w = -array.shape[1] % multiple
+        padding = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (array.ndim - 2)
+        return np.pad(array, padding)
+
+    sample = dataset[0]
+    image = pad_to_multiple(sample["image"]).astype("float32")[np.newaxis]
+    outputs = model(image)
+
+    # Compute the loss; targets and masks get batch and channel dimensions
+    # so they broadcast against the (1, H, W, 1) predictions. The losses
+    # are invoked via `.call(...)` because their signatures take extra
+    # arguments beyond the standard Keras (y_true, y_pred) pair.
+    def as_batched_map(array):
+        return pad_to_multiple(array)[np.newaxis, ..., np.newaxis]
+
+    loss = db_loss.call(
+        y_true=[
+            as_batched_map(sample["shrink_map"]),
+            as_batched_map(sample["threshold_map"]),
+        ],
+        y_pred=outputs,
+        mask=[
+            as_batched_map(sample["shrink_mask"]),
+            as_batched_map(sample["threshold_mask"]),
+        ],
+    )
+
+    # TODO: Backpropagate the loss
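+
+    # ------------------------------------------------------------------
+    # Hedged inference sketch (not the original author's code). Channel 0
+    # of the training output is the probability map; at test time the DB
+    # paper binarizes it with a fixed threshold of 0.3. Note that because
+    # DBHead builds fresh layers on every call, constructing a separate
+    # `training=False` model here would NOT share the trained head's
+    # weights; reusing the output computed above avoids that pitfall.
+    probability_map = outputs[..., 0]
+    text_mask = keras.ops.cast(probability_map > 0.3, "float32")
+    print("Detected text pixels:", int(keras.ops.sum(text_mask)))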
diff --git a/keras_cv/layers/object_detection/anchor_generator.py b/keras_cv/layers/object_detection/anchor_generator.py
index 30dd421afd..effc125143 100644
--- a/keras_cv/layers/object_detection/anchor_generator.py
+++ b/keras_cv/layers/object_detection/anchor_generator.py
@@ -172,7 +172,7 @@ def __call__(self, image=None, image_shape=None):
                 "Expected `image` to be a Tensor of rank 3. Got "
                 f"image.shape.rank={len(image.shape)}"
             )
-        image_shape = image.shape
+        image_shape = tuple(image.shape)
 
         results = {}
         for key, generator in self.anchor_generators.items():