Skip to content

Commit

Permalink
Update Handwritten Recognition example to keras version3 (#1916)
Browse files Browse the repository at this point in the history
* Update Handwritten Recognition example to keras version3

* Update with .py file on keras3 changes

* Updated with .py file on keras3 changes

* Update .py file with reformatting changes

* .py file with reformatting changes

* Replace modified handwritten_recognition.ipynb and handwritten_recognition.py

* handwriting_recognition.py reformatted

* Reformatted .py file

* Updated py and ipynb file changes

* Update .md file
  • Loading branch information
mehtamansi29 authored Oct 23, 2024
1 parent 695a0b5 commit 2aca792
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 163 deletions.
55 changes: 30 additions & 25 deletions examples/vision/handwriting_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: Handwriting recognition
Authors: [A_K_Nain](https://twitter.com/A_K_Nain), [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/08/16
Last modified: 2023/07/06
Last modified: 2024/09/01
Description: Training a handwriting recognition model with variable-length sequences.
Accelerator: GPU
"""
Expand Down Expand Up @@ -45,16 +45,16 @@
## Imports
"""

from tensorflow.keras.layers import StringLookup
from tensorflow import keras

import keras
from keras.layers import StringLookup
from keras import ops
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os

np.random.seed(42)
tf.random.set_seed(42)
keras.utils.set_random_seed(42)

"""
## Dataset splitting
Expand Down Expand Up @@ -213,8 +213,8 @@ def distortion_free_resize(image, img_size):
image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)

# Check the amount of padding needed to be done.
pad_height = h - tf.shape(image)[0]
pad_width = w - tf.shape(image)[1]
pad_height = h - ops.shape(image)[0]
pad_width = w - ops.shape(image)[1]

# Only necessary if you want to do the same amount of padding on both sides.
if pad_height % 2 != 0:
Expand All @@ -240,7 +240,7 @@ def distortion_free_resize(image, img_size):
],
)

image = tf.transpose(image, perm=[1, 0, 2])
image = ops.transpose(image, (1, 0, 2))
image = tf.image.flip_left_right(image)
return image

Expand All @@ -267,13 +267,13 @@ def preprocess_image(image_path, img_size=(image_width, image_height)):
image = tf.io.read_file(image_path)
image = tf.image.decode_png(image, 1)
image = distortion_free_resize(image, img_size)
image = tf.cast(image, tf.float32) / 255.0
image = ops.cast(image, tf.float32) / 255.0
return image


def vectorize_label(label):
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
length = tf.shape(label)[0]
length = ops.shape(label)[0]
pad_amount = max_len - length
label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
return label
Expand Down Expand Up @@ -312,7 +312,7 @@ def prepare_dataset(image_paths, labels):
for i in range(16):
img = images[i]
img = tf.image.flip_left_right(img)
img = tf.transpose(img, perm=[1, 0, 2])
img = ops.transpose(img, (1, 0, 2))
img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
img = img[:, :, 0]

Expand Down Expand Up @@ -346,15 +346,15 @@ def prepare_dataset(image_paths, labels):
class CTCLayer(keras.layers.Layer):
def __init__(self, name=None):
super().__init__(name=name)
self.loss_fn = keras.backend.ctc_batch_cost
self.loss_fn = tf.keras.backend.ctc_batch_cost

def call(self, y_true, y_pred):
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")

input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
self.add_loss(loss)

Expand Down Expand Up @@ -455,14 +455,14 @@ def build_model():

def calculate_edit_distance(labels, predictions):
# Get a single batch and convert its labels to sparse tensors.
saprse_labels = tf.cast(tf.sparse.from_dense(labels), dtype=tf.int64)
saprse_labels = ops.cast(tf.sparse.from_dense(labels), dtype=tf.int64)

# Make predictions and convert them to sparse tensors.
input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
predictions_decoded = keras.backend.ctc_decode(
predictions, input_length=input_len, greedy=True
predictions_decoded = keras.ops.nn.ctc_decode(
predictions, sequence_lengths=input_len
)[0][0][:, :max_len]
sparse_predictions = tf.cast(
sparse_predictions = ops.cast(
tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
)

Expand Down Expand Up @@ -501,7 +501,7 @@ def on_epoch_end(self, epoch, logs=None):

model = build_model()
prediction_model = keras.models.Model(
model.get_layer(name="image").input, model.get_layer(name="dense2").output
model.get_layer(name="image").output, model.get_layer(name="dense2").output
)
edit_distance_callback = EditDistanceCallback(prediction_model)

Expand All @@ -523,14 +523,19 @@ def on_epoch_end(self, epoch, logs=None):
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search.
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
results = keras.ops.nn.ctc_decode(pred, sequence_lengths=input_len)[0][0][
:, :max_len
]
# Iterate over the results and get back the text.
output_text = []
for res in results:
res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
res = (
tf.strings.reduce_join(num_to_char(res))
.numpy()
.decode("utf-8")
.replace("[UNK]", "")
)
output_text.append(res)
return output_text

Expand All @@ -546,7 +551,7 @@ def decode_batch_predictions(pred):
for i in range(16):
img = batch_images[i]
img = tf.image.flip_left_right(img)
img = tf.transpose(img, perm=[1, 0, 2])
img = ops.transpose(img, (1, 0, 2))
img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
img = img[:, :, 0]

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 2aca792

Please sign in to comment.