Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

When Recurrence meets Transformers to keras 3.0 (Tensorflow backend only) #1984

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 87 additions & 55 deletions examples/vision/temporal_latent_bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: When Recurrence meets Transformers
Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
Date created: 2023/03/12
Last modified: 2023/03/12
Last modified: 2024/10/29
Description: Image Classification with Temporal Latent Bottleneck Networks.
Accelerator: GPU
"""
Expand Down Expand Up @@ -51,21 +51,20 @@
"""
## Setup imports
"""
import os

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import AdamW
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers, ops, mixed_precision
from keras.optimizers import AdamW
import numpy as np
import random
from matplotlib import pyplot as plt

# Set seed for reproducibility.
keras.utils.set_random_seed(42)

AUTO = tf.data.AUTOTUNE

"""
## Setting required configuration

Expand Down Expand Up @@ -184,41 +183,49 @@ def test_map_fn(image, label):


"""
## Load dataset into `tf.data.Dataset` object
## Load dataset into `PyDataset` object

- We take the `np.ndarray` instance of the datasets and move them into a
`tf.data.Dataset` instance
- Apply augmentations using
[`.map()`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map)
- Shuffle the dataset using
[`.shuffle()`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle)
- Batch the dataset using
[`.batch()`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch)
- Enable pre-fetching of batches using
[`.prefetch()`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch)
- We take the `np.ndarray` instances of the datasets and wrap them in a class
that subclasses `keras.utils.PyDataset`, applying augmentations with Keras
preprocessing layers.
"""

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
train_ds.map(train_map_fn, num_parallel_calls=AUTO)
.shuffle(config["buffer_size"])
.batch(config["batch_size"], num_parallel_calls=AUTO)
.prefetch(AUTO)
)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
val_ds.map(test_map_fn, num_parallel_calls=AUTO)
.batch(config["batch_size"], num_parallel_calls=AUTO)
.prefetch(AUTO)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
test_ds.map(test_map_fn, num_parallel_calls=AUTO)
.batch(config["batch_size"], num_parallel_calls=AUTO)
.prefetch(AUTO)
class Dataset(keras.utils.PyDataset):
    """Serve in-memory arrays as fixed-size batches with optional preprocessing.

    Args:
        x_data: Array of inputs; indexed by integer and, when `shuffle` is
            set, by a permutation array.
        y_data: Array of labels aligned with `x_data`.
        batch_size: Number of samples in every batch.
        preprocess_fn: Optional callable `(x, y) -> (x, y)` applied to each
            sample before it is stacked into the batch.
        shuffle: If True, both arrays are permuted once, at construction
            time, with the same random permutation.
        **kwargs: Forwarded unchanged to `keras.utils.PyDataset.__init__`.
    """

    def __init__(
        self, x_data, y_data, batch_size, preprocess_fn=None, shuffle=False, **kwargs
    ):
        # Forward as keyword arguments. Unpacking with a single `*` would pass
        # the dict's *keys* positionally and silently misconfigure the parent.
        super().__init__(**kwargs)
        if shuffle:
            # One shared permutation keeps inputs and labels aligned.
            perm = np.random.permutation(len(x_data))
            x_data = x_data[perm]
            y_data = y_data[perm]
        self.x_data = x_data
        self.y_data = y_data
        self.preprocess_fn = preprocess_fn
        self.batch_size = batch_size

    def __len__(self):
        # Floor division: a trailing partial batch is dropped, so every batch
        # has exactly `batch_size` samples.
        return len(self.x_data) // self.batch_size

    def __getitem__(self, idx):
        """Return batch number `idx` as a `(batch_x, batch_y)` pair of stacked arrays."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x, batch_y = [], []
        for i in range(start, stop):
            x, y = self.x_data[i], self.y_data[i]
            if self.preprocess_fn:
                x, y = self.preprocess_fn(x, y)
            batch_x.append(x)
            batch_y.append(y)
        batch_x = np.stack(batch_x, axis=0)
        batch_y = np.stack(batch_y, axis=0)
        return batch_x, batch_y


train_ds = Dataset(
x_train, y_train, config["batch_size"], preprocess_fn=train_map_fn, shuffle=True
)
val_ds = Dataset(x_val, y_val, config["batch_size"], preprocess_fn=test_map_fn)
test_ds = Dataset(x_test, y_test, config["batch_size"], preprocess_fn=test_map_fn)

"""
## Temporal Latent Bottleneck
Expand Down Expand Up @@ -310,7 +317,7 @@ def __init__(
self.num_patches = patch_resolution[0] * patch_resolution[1]

# Define the positions of the patches.
self.positions = tf.range(start=0, limit=self.num_patches, delta=1)
self.positions = ops.arange(start=0, stop=self.num_patches, step=1)

# Create the layers.
self.projection = layers.Conv2D(
Expand Down Expand Up @@ -375,7 +382,7 @@ def __init__(self, dims, dropout, **kwargs):
# Create the layers.
self.ffn = keras.Sequential(
[
layers.Dense(units=4 * dims, activation=tf.nn.gelu),
layers.Dense(units=4 * dims, activation="gelu"),
layers.Dense(units=dims),
layers.Dropout(rate=dropout),
],
Expand Down Expand Up @@ -484,7 +491,13 @@ def __init__(
):
super().__init__(**kwargs)
# Create the layers.
self.attention = BaseAttention(
self.fast_stream_attention = BaseAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=attn_dropout,
name="base_attn",
)
self.slow_stream_attention = BaseAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=attn_dropout,
Expand All @@ -498,12 +511,25 @@ def __init__(

self.attention_scores = None

def call(self, query, key, value):
def build(self, input_shape):
self.built = True

def call(self, query, key, value, stream="fast"):
# Apply the attention module.
x = self.attention(query, key, value)
attention_layer = {
"fast": self.fast_stream_attention,
"slow": self.slow_stream_attention,
}[stream]
if len(query.shape) == 2:
query = ops.expand_dims(query, -1)
if len(key.shape) == 2:
key = ops.expand_dims(key, -1)
if len(value.shape) == 2:
value = ops.expand_dims(value, -1)
x = attention_layer(query, key, value)

# Save the attention scores for later visualization.
self.attention_scores = self.attention.attention_scores
self.attention_scores = attention_layer.attention_scores

# Apply the FFN.
x = self.ffn(x)
Expand Down Expand Up @@ -577,10 +603,9 @@ def __init__(
self.key_dim = key_dim
self.attn_dropout = attn_dropout

# Create the state_size and output_size. This is important for
# Create state_size. This is important for
# custom recurrence logic.
self.state_size = tf.TensorShape([chunk_size, ffn_dims])
self.output_size = tf.TensorShape([chunk_size, ffn_dims])
self.state_size = chunk_size * ffn_dims

self.get_attention_scores = False
self.attention_scores = []
Expand Down Expand Up @@ -621,18 +646,23 @@ def __init__(
name=f"tlb_cross_attn_ffn",
)

def build(self, input_shape):
self.built = True

def call(self, inputs, states):
# inputs => (batch, chunk_size, dims)
# states => [(batch, chunk_size, units)]
slow_stream = states[0]
slow_stream = ops.reshape(states[0], (-1, self.chunk_size, self.ffn_dims))
fast_stream = inputs

for layer_idx, layer in enumerate(self.perceptual_module):
fast_stream = layer(query=fast_stream, key=fast_stream, value=fast_stream)
fast_stream = layer(
query=fast_stream, key=fast_stream, value=fast_stream, stream="fast"
)

if layer_idx % self.r == 0:
fast_stream = layer(
query=fast_stream, key=slow_stream, value=slow_stream
query=fast_stream, key=slow_stream, value=slow_stream, stream="slow"
)

slow_stream = self.tlb_module(
Expand All @@ -643,7 +673,9 @@ def call(self, inputs, states):
if self.get_attention_scores:
self.attention_scores.append(self.tlb_module.attention_scores)

return fast_stream, [slow_stream]
return fast_stream, [
ops.reshape(slow_stream, (-1, self.chunk_size * self.ffn_dims))
]


"""
Expand Down Expand Up @@ -773,9 +805,9 @@ def call(self, inputs):

def score_to_viz(chunk_score):
# get the most attended token
chunk_viz = tf.math.reduce_max(chunk_score, axis=-2)
chunk_viz = ops.max(chunk_score, axis=-2)
# get the mean across heads
chunk_viz = tf.math.reduce_mean(chunk_viz, axis=1)
chunk_viz = ops.mean(chunk_viz, axis=1)
return chunk_viz


Expand All @@ -792,8 +824,8 @@ def score_to_viz(chunk_score):

# Process the attention scores in order to visualize them
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores]
chunk_viz = tf.concat(list_chunk_viz[1:], axis=-1)
chunk_viz = tf.reshape(
chunk_viz = ops.concatenate(list_chunk_viz[1:], axis=-1)
chunk_viz = ops.reshape(
chunk_viz,
(
config["batch_size"],
Expand Down
Loading