diff --git a/.gitignore b/.gitignore
index fe0a5871bf..b9d02176e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,8 @@ templates/examples/graph/*
 templates/guides/**/*.md
 templates/keras-tuner/getting_started.md
 datasets/*
-.vscode/*
\ No newline at end of file
+.vscode/*
+**/*.zip
+**/*.tgz
+**/*.tar.gz
+**/data/*
\ No newline at end of file
diff --git a/examples/vision/shiftvit.py b/examples/vision/shiftvit.py
index e0c42d1d20..1e88dbb0cc 100644
--- a/examples/vision/shiftvit.py
+++ b/examples/vision/shiftvit.py
@@ -2,7 +2,7 @@
 Title: A Vision Transformer without Attention
 Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
 Date created: 2022/02/24
-Last modified: 2022/10/15
+Last modified: 2024/05/03
 Description: A minimal implementation of ShiftViT.
 Accelerator: GPU
 """
@@ -26,32 +26,32 @@
 In this example, we minimally implement the paper with close alignment to the author's
 [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
-
-This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can
-be installed using the following command:
-"""
-"""shell
-pip install -qq -U tensorflow-addons
 """

 """
 ## Setup and imports
 """

+import os
+
+if os.environ.get("KERAS_BACKEND") is None:
+    # @param ["tensorflow", "torch", "jax"]
+    os.environ["KERAS_BACKEND"] = "tensorflow"
+
+# Import tensorflow because of tf.data.Dataset dependencies.
+import tensorflow as tf
+
 import numpy as np
 import matplotlib.pyplot as plt
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa
+import keras
+from keras import layers
+from keras import ops

-import pathlib
-import glob
+from pathlib import Path

-# Setting seed for reproducibiltiy
-SEED = 42
-keras.utils.set_random_seed(SEED)
+# Setting seed for reproducibility.
+keras.utils.set_random_seed(seed=42)

 """
 ## Hyperparameters
@@ -242,7 +242,7 @@ def build(self, input_shape):
             [
                 layers.Dense(
                     units=initial_filters,
-                    activation=tf.nn.gelu,
+                    activation=keras.activations.gelu,
                 ),
                 layers.Dropout(rate=self.mlp_dropout_rate),
                 layers.Dense(units=input_channels),
@@ -268,7 +268,7 @@ def call(self, x):
 class DropPath(layers.Layer):
     """Drop Path also known as the Stochastic Depth layer.
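+
+    Example (a minimal sketch): the layer zeroes out whole samples with
+    probability `drop_path_prob` and rescales the survivors, so output
+    shapes always match input shapes:
+
+        drop_path = DropPath(drop_path_prob=0.1)
+        y = drop_path(ops.ones((2, 8, 8, 16)), training=True)  # (2, 8, 8, 16)
+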
-    Refernece:
+    Reference:
         - https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization
         - github.com:rwightman/pytorch-image-models
     """
@@ -280,9 +280,9 @@ def __init__(self, drop_path_prob, **kwargs):
     def call(self, x, training=False):
         if training:
             keep_prob = 1 - self.drop_path_prob
-            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
-            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
-            random_tensor = tf.floor(random_tensor)
+            shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
+            random_tensor = keep_prob + keras.random.uniform(shape)
+            random_tensor = ops.floor(random_tensor)
             return (x / keep_prob) * random_tensor
         return x
@@ -396,17 +396,17 @@ def get_shift_pad(self, x, mode):
             offset_width = 0
             target_height = self.shift_pixel
             target_width = 0
-        crop = tf.image.crop_to_bounding_box(
+        crop = ops.image.crop_images(
             x,
-            offset_height=offset_height,
-            offset_width=offset_width,
+            top_cropping=offset_height,
+            left_cropping=offset_width,
             target_height=self.H - target_height,
             target_width=self.W - target_width,
         )
-        shift_pad = tf.image.pad_to_bounding_box(
+        shift_pad = ops.image.pad_images(
             crop,
-            offset_height=offset_height,
-            offset_width=offset_width,
+            top_padding=offset_height,
+            left_padding=offset_width,
             target_height=self.H,
             target_width=self.W,
         )
@@ -414,7 +414,7 @@ def get_shift_pad(self, x, mode):
     def call(self, x, training=False):
         # Split the feature maps
-        x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)
+        x_splits = ops.split(x, indices_or_sections=self.C // self.num_div, axis=-1)

         # Shift the feature maps
         x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
@@ -423,7 +423,7 @@ def call(self, x, training=False):
         x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")

         # Concatenate the shifted and unshifted feature maps
-        x = tf.concat(x_splits, axis=-1)
+        x = ops.concatenate(x_splits, axis=-1)

         # Add the residual connection
         shortcut = x
@@ -588,6 +588,9 @@ def get_config(self):
 """


+@keras.saving.register_keras_serializable(
+    package="my_shiftvit_package", name="shiftvitmodel"
+)
 class ShiftViTModel(keras.Model):
     """The ShiftViT Model.

@@ -617,6 +620,15 @@ def __init__(
         epsilon,
         mlp_dropout_rate,
         stochastic_depth_rate,
+        classifier=layers.Dense(config.num_classes),
+        global_avg_pool=layers.GlobalAveragePooling2D(),
+        patch_projection=layers.Conv2D(
+            filters=config.projected_dim,
+            kernel_size=config.patch_size,
+            strides=config.patch_size,
+            padding="same",
+        ),
+        stages=None,
         num_div=12,
         shift_pixel=1,
         mlp_expand_ratio=2,
@@ -624,13 +636,20 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.data_augmentation = data_augmentation
-        self.patch_projection = layers.Conv2D(
-            filters=projected_dim,
-            kernel_size=patch_size,
-            strides=patch_size,
-            padding="same",
-        )
-        self.stages = list()
+        self.patch_projection = patch_projection
+        self.classifier = classifier
+        self.global_avg_pool = global_avg_pool
+        self.projected_dim = projected_dim
+        self.patch_size = patch_size
+        self.num_shift_blocks_per_stages = num_shift_blocks_per_stages
+        self.epsilon = epsilon
+        self.mlp_dropout_rate = mlp_dropout_rate
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.num_div = num_div
+        self.shift_pixel = shift_pixel
+        self.mlp_expand_ratio = mlp_expand_ratio
+        # Avoid a shared mutable default: build the stage list per instance.
+        self.stages = stages if stages is not None else list()
+
         for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
             if index == len(num_shift_blocks_per_stages) - 1:
                 # This is the last stage, do not use the patch merge here.
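As a quick sanity check on the crop-then-pad shift that `get_shift_pad` performs above, here is a minimal, self-contained sketch. It uses standalone tensors rather than the model's layers, with the same `keras.ops.image` keywords as the diff; the one-pixel left shift shown is analogous to one of the four shift modes:

import keras
from keras import ops

# A single 1x4x4x1 "image" with distinct pixel values.
x = ops.reshape(ops.arange(16, dtype="float32"), (1, 4, 4, 1))

# Drop the leftmost column...
crop = ops.image.crop_images(
    x, top_cropping=0, left_cropping=1, target_height=4, target_width=3
)
# ...then zero-pad the right edge back to the original width, which moves
# every remaining pixel one position to the left.
shifted = ops.image.pad_images(
    crop, top_padding=0, left_padding=0, target_height=4, target_width=4
)

# Each row shifts left by one pixel; a zero column appears on the right.
print(ops.convert_to_numpy(shifted)[0, ..., 0])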
@@ -650,9 +669,6 @@ def __init__(
                     mlp_expand_ratio=mlp_expand_ratio,
                 )
             )
-        self.global_avg_pool = layers.GlobalAveragePooling2D()
-
-        self.classifier = layers.Dense(config.num_classes)

     def get_config(self):
         config = super().get_config()
@@ -663,66 +679,19 @@ def get_config(self):
                 "stages": self.stages,
                 "global_avg_pool": self.global_avg_pool,
                 "classifier": self.classifier,
+                "projected_dim": self.projected_dim,
+                "patch_size": self.patch_size,
+                "num_shift_blocks_per_stages": self.num_shift_blocks_per_stages,
+                "epsilon": self.epsilon,
+                "mlp_dropout_rate": self.mlp_dropout_rate,
+                "stochastic_depth_rate": self.stochastic_depth_rate,
+                "num_div": self.num_div,
+                "shift_pixel": self.shift_pixel,
+                "mlp_expand_ratio": self.mlp_expand_ratio,
             }
         )
         return config

-    def _calculate_loss(self, data, training=False):
-        (images, labels) = data
-
-        # Augment the images
-        augmented_images = self.data_augmentation(images, training=training)
-
-        # Create patches and project the pathces.
-        projected_patches = self.patch_projection(augmented_images)
-
-        # Pass through the stages
-        x = projected_patches
-        for stage in self.stages:
-            x = stage(x, training=training)
-
-        # Get the logits.
-        x = self.global_avg_pool(x)
-        logits = self.classifier(x)
-
-        # Calculate the loss and return it.
-        total_loss = self.compiled_loss(labels, logits)
-        return total_loss, labels, logits
-
-    def train_step(self, inputs):
-        with tf.GradientTape() as tape:
-            total_loss, labels, logits = self._calculate_loss(
-                data=inputs, training=True
-            )
-
-        # Apply gradients.
-        train_vars = [
-            self.data_augmentation.trainable_variables,
-            self.patch_projection.trainable_variables,
-            self.global_avg_pool.trainable_variables,
-            self.classifier.trainable_variables,
-        ]
-        train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
-
-        # Optimize the gradients.
-        grads = tape.gradient(total_loss, train_vars)
-        trainable_variable_list = []
-        for grad, var in zip(grads, train_vars):
-            for g, v in zip(grad, var):
-                trainable_variable_list.append((g, v))
-        self.optimizer.apply_gradients(trainable_variable_list)
-
-        # Update the metrics
-        self.compiled_metrics.update_state(labels, logits)
-        return {m.name: m.result() for m in self.metrics}
-
-    def test_step(self, data):
-        _, labels, logits = self._calculate_loss(data=data, training=False)
-
-        # Update the metrics
-        self.compiled_metrics.update_state(labels, logits)
-        return {m.name: m.result() for m in self.metrics}
-
     def call(self, images):
         augmented_images = self.data_augmentation(images)
         x = self.patch_projection(augmented_images)
@@ -762,6 +731,8 @@ def call(self, images):
 # Some code is taken from:
 # https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
+#
+# The original implementation has been adapted to use Keras 3 ops.


 class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
     """A LearningRateSchedule that uses a warmup cosine decay schedule."""
@@ -779,7 +750,6 @@ def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
         self.lr_max = lr_max
         self.warmup_steps = warmup_steps
         self.total_steps = total_steps
-        self.pi = tf.constant(np.pi)

     def __call__(self, step):
         # Check whether the total number of steps is larger than the warmup
@@ -793,10 +763,10 @@ def __call__(self, step):
         # `cos_annealed_lr` is a graph that increases to 1 from the initial
         # step to the warmup step. After that this graph decays to -1 at the
         # final step mark.
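+        # Concretely (a restatement of the expression below, with s = step,
+        # w = warmup_steps, T = total_steps):
+        #   cos_annealed_lr = cos(pi * (s - w) / (T - w))
+        # which equals 1 at s = w and -1 at s = T.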
-        cos_annealed_lr = tf.cos(
-            self.pi
-            * (tf.cast(step, tf.float32) - self.warmup_steps)
-            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
+        cos_annealed_lr = ops.cos(
+            np.pi
+            * (ops.cast(step, dtype="float32") - self.warmup_steps)
+            / ops.cast(self.total_steps - self.warmup_steps, dtype="float32")
         )

         # Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes
@@ -821,20 +791,18 @@ def __call__(self, step):

         # With the formula for a straight line (y = mx + c) build the warmup
         # schedule
-        warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
+        warmup_rate = slope * ops.cast(step, dtype="float32") + self.lr_start

         # When the current step is less than the warmup steps, get the line
         # graph. When the current step is greater than the warmup steps, get
         # the scaled cos graph.
-        learning_rate = tf.where(
+        learning_rate = ops.where(
             step < self.warmup_steps, warmup_rate, learning_rate
         )

         # When the current step is more than the total steps, return 0, else return
         # the calculated graph.
-        return tf.where(
-            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
-        )
+        return ops.where(step > self.total_steps, 0.0, learning_rate)

     def get_config(self):
         config = {
@@ -871,7 +839,7 @@ def get_config(self):
 )

 # Get the optimizer.
-optimizer = tfa.optimizers.AdamW(
+optimizer = keras.optimizers.AdamW(
     learning_rate=scheduled_lrs, weight_decay=config.weight_decay
 )

@@ -911,9 +879,10 @@ def get_config(self):
 Since we created the model by subclassing, we can't save the model in HDF5 format.
-It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well.
+It can be saved in the native Keras format only. In general, this is the recommended
+format for saving models in Keras 3.
 """

-model.save("ShiftViT")
+model.save("ShiftViT.keras")

 """
 ## Model inference
 """

 # Custom objects are not included when the model is saved.
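+# (An exception: classes registered via `keras.saving.register_keras_serializable`,
+# like `ShiftViTModel` above, can be reconstructed without being passed explicitly.)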
 # At loading time, these objects need to be passed for reconstruction of the model
-saved_model = tf.keras.models.load_model(
-    "ShiftViT",
-    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
+saved_model = keras.saving.load_model(
+    "ShiftViT.keras",
+    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": keras.optimizers.AdamW},
 )

 """
@@ -944,30 +913,14 @@
 """


-def process_image(img_path):
-    # read image file from string path
-    img = tf.io.read_file(img_path)
-
-    # decode jpeg to uint8 tensor
-    img = tf.io.decode_jpeg(img, channels=3)
-
-    # resize image to match input size accepted by model
-    # use `method` as `nearest` to preserve dtype of input passed to `resize()`
-    img = tf.image.resize(
-        img, [config.input_shape[0], config.input_shape[1]], method="nearest"
-    )
-    return img
-
-
 def create_tf_dataset(image_dir):
-    data_dir = pathlib.Path(image_dir)
-
-    # create tf.data dataset using directory of images
-    predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
-
-    # use map to convert string paths to uint8 image tensors
-    # setting `num_parallel_calls' helps in processing multiple images parallely
-    predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
+    predict_ds = keras.utils.image_dataset_from_directory(
+        directory=Path(image_dir),
+        image_size=(config.input_shape[0], config.input_shape[1]),
+        labels=None,
+        interpolation="nearest",
+        batch_size=config.tf_ds_batch_size,
+    )

-    # create a Prefetch Dataset for better latency & throughput
-    predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
+    # The dataset is already batched above; batching again would yield 5D
+    # inputs. Prefetch for better latency & throughput.
+    predict_ds = predict_ds.prefetch(AUTO)
     return predict_ds


 def predict(predict_ds):
     logits = saved_model.predict(predict_ds)

     # normalize predictions by calling softmax()
-    probabilities = tf.nn.softmax(logits)
+    probabilities = ops.softmax(logits)
     return probabilities


 def get_predicted_class(probabilities):
-    pred_label = np.argmax(probabilities)
+    # `convert_to_numpy` keeps this backend-agnostic (`.numpy()` is not
+    # available on all backends).
+    pred_label = ops.convert_to_numpy(ops.argmax(probabilities)).item()
     predicted_class = config.label_map[pred_label]
     return predicted_class


 def get_confidence_scores(probabilities):
     # get the indices of the probability scores sorted in descending order
-    labels = np.argsort(probabilities)[::-1]
+    labels = ops.convert_to_numpy(ops.argsort(probabilities)[::-1])
     confidences = {
-        config.label_map[label]: np.round((probabilities[label]) * 100, 2)
+        config.label_map[label]: ops.convert_to_numpy(
+            ops.round((probabilities[label]) * 100, 2)
+        ).item(0)
         for label in labels
     }
     return confidences
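To tie the inference helpers together, a minimal usage sketch (the `inference_images` directory of JPEGs is hypothetical; `create_tf_dataset`, `predict`, and the two helpers are defined above):

# Build a batched, prefetched dataset from a folder of images and run the
# saved model over it.
predict_ds = create_tf_dataset("inference_images")
probabilities = predict(predict_ds)

# Report the top class and the per-class confidence scores for the first image.
print(get_predicted_class(probabilities[0]))
print(get_confidence_scores(probabilities[0]))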