diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
index 76ac0631b4..6bd9d341a6 100644
--- a/.kokoro/github/ubuntu/gpu/build.sh
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -69,6 +69,7 @@ then
    keras_cv/models/object_detection/retinanet \
    keras_cv/models/object_detection/yolo_v8 \
    keras_cv/models/object_detection_3d \
+   keras_cv/models/feature_extractor/coca \
    keras_cv/models/segmentation \
    keras_cv/models/stable_diffusion
 else
@@ -83,6 +84,7 @@ else
    keras_cv/models/object_detection/retinanet \
    keras_cv/models/object_detection/yolo_v8 \
    keras_cv/models/object_detection_3d \
+   keras_cv/models/feature_extractor/coca \
    keras_cv/models/segmentation \
    keras_cv/models/stable_diffusion
 fi
\ No newline at end of file
diff --git a/keras_cv/layers/transformer_encoder.py b/keras_cv/layers/transformer_encoder.py
index 152fe354f8..7d6674b9d6 100644
--- a/keras_cv/layers/transformer_encoder.py
+++ b/keras_cv/layers/transformer_encoder.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from tensorflow import keras
-from tensorflow.keras import layers
+import keras
+from keras import layers
 
 from keras_cv.api_export import keras_cv_export
 
diff --git a/keras_cv/models/feature_extractor/coca/__init__.py b/keras_cv/models/feature_extractor/coca/__init__.py
new file mode 100644
index 0000000000..5372894aca
--- /dev/null
+++ b/keras_cv/models/feature_extractor/coca/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_cv.models.feature_extractor.coca.coca_layers import CoCaAttentionPooling
+from keras_cv.models.feature_extractor.coca.coca_model import CoCa
diff --git a/keras_cv/models/feature_extractor/coca/coca_layers.py b/keras_cv/models/feature_extractor/coca/coca_layers.py
new file mode 100644
index 0000000000..25bbbc1a60
--- /dev/null
+++ b/keras_cv/models/feature_extractor/coca/coca_layers.py
@@ -0,0 +1,39 @@
+from keras import layers
+
+
+class CoCaAttentionPooling(layers.Layer):
+    """Implements the pooled attention layer used in "CoCa: Contrastive
+    Captioners are Image-Text Foundation Models"
+    (https://arxiv.org/pdf/2205.01917.pdf), consisting of a multi-head
+    attention followed by a layer normalization.
+
+    Args:
+        head_dim: The dimension of the attention heads
+        num_heads: The number of attention heads in the multi-head attention layer
+    """
+
+    def __init__(self, head_dim, num_heads, **kwargs):
+        super().__init__(**kwargs)
+
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+
+        self.multi_head_attn = layers.MultiHeadAttention(
+            self.num_heads, self.head_dim
+        )
+
+        self.layer_norm = layers.LayerNormalization()
+
+    def build(self, input_shape):
+        if len(input_shape) < 2:
+            raise ValueError(
+                "Building CoCaAttentionPooling requires an input shape of the "
+                "form (query_shape, value_shape)"
+            )
+
+        query_shape = input_shape[0]
+        value_shape = input_shape[1]
+
+        self.multi_head_attn.build(query_shape, value_shape)
+        self.layer_norm.build(query_shape)
+        self.built = True
+
+    def call(self, query, value):
+        x = self.multi_head_attn(query, value)
+        return self.layer_norm(x)
diff --git a/keras_cv/models/feature_extractor/coca/coca_model.py b/keras_cv/models/feature_extractor/coca/coca_model.py
new file mode 100644
index 0000000000..707cf2eb43
--- /dev/null
+++ b/keras_cv/models/feature_extractor/coca/coca_model.py
@@ -0,0 +1,279 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+from keras import Sequential
+from keras_nlp.layers import RotaryEmbedding
+from keras_nlp.layers import TransformerDecoder
+
+from keras_cv.api_export import keras_cv_export
+from keras_cv.backend import ops
+from keras_cv.layers import TransformerEncoder as CVTransformerEncoder
+from keras_cv.layers.vit_layers import PatchingAndEmbedding
+from keras_cv.models.feature_extractor.coca.coca_layers import CoCaAttentionPooling
+from keras_cv.models.task import Task
+
+
+@keras_cv_export(["keras_cv.models.CoCa"])
+class CoCa(Task):
+    """Contrastive Captioner (CoCa) foundation model implementation.
+
+    This model implements "CoCa: Contrastive Captioners are Image-Text
+    Foundation Models" by Yu et al. (https://arxiv.org/pdf/2205.01917.pdf).
+    In short, the CoCa model combines the ideas of contrastive techniques
+    such as CLIP with generative captioning approaches such as SimVLM.
+
+    The architecture of CoCa can be described as an image Vision Transformer
+    encoder in parallel with a self-attention-only text Transformer decoder,
+    the outputs of both of which are passed into a multimodal Transformer
+    decoder. The contrastive loss between the ViT and unimodal text decoder
+    outputs is combined with a captioning loss from the multimodal decoder to
+    produce the total loss.
+
+    Basic Usage:
+    ```python
+    images = ...  # [batch_size, height, width, channels]
+    captions = ...  # [batch_size, sequence_length, text_dim]
+
+    coca = CoCa()
+
+    # Dict with "multimodal_out" and "contrastive_feature" entries
+    output = coca({"images": images, "captions": captions})
+    ```
+
+    All default arguments are consistent with the original paper's details.
+
+    Args:
+        img_shape: The shape of a single image, typically expressed as [height, width, channels]
+        caption_shape: The shape of a single caption, typically expressed as [sequence_length, text_dim]
+        img_patch_size: the size N of the NxN patches extracted from the input images
+        encoder_depth: number of image encoder blocks
+        encoder_heads: number of attention heads used in each image encoder block
+        encoder_intermediate_dim: dimensionality of the encoder blocks' intermediate representation (MLP dimensionality)
+        encoder_width: dimensionality of the encoder's projection, consistent with the wording used in the CoCa paper
+        unimodal_decoder_depth: number of decoder blocks used for text self-attention/embedding
+        multimodal_decoder_depth: number of decoder blocks used for image-text cross-attention and captioning
+        decoder_intermediate_dim: dimensionality of the decoder blocks' MLPs
+        unimodal_decoder_heads: number of attention heads in the unimodal decoder
+        multimodal_decoder_heads: number of attention heads in the multimodal decoder
+        contrastive_query_length: number of tokens used to represent the contrastive query
+        captioning_query_length: number of tokens used to represent the captioning query
+        contrastive_attn_heads: number of attention heads used for the contrastive attention pooling
+        captioning_attn_heads: number of attention heads used for the captioning attention pooling
+        contrastive_loss_weight: weighting of the contrastive loss
+        captioning_loss_weight: weighting of the captioning loss
+    """
+
+    def __init__(
+        self,
+        img_shape=(512, 512, 3),
+        caption_shape=(10, 48),
+        img_patch_size=18,
+        encoder_depth=40,
+        encoder_heads=16,
+        encoder_intermediate_dim=6144,
+        encoder_width=1408,
+        unimodal_decoder_depth=18,
+        multimodal_decoder_depth=18,
+        decoder_intermediate_dim=5632,
+        unimodal_decoder_heads=16,
+        multimodal_decoder_heads=16,
+        contrastive_query_length=1,
+        captioning_query_length=256,
+        contrastive_attn_heads=16,
+        captioning_attn_heads=16,
+        contrastive_loss_weight=0.5,
+        captioning_loss_weight=0.5,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        #
+        # Save Details
+        #
+        self.img_shape = img_shape
+        self.caption_shape = caption_shape
+
+        self.img_patch_size = img_patch_size
+
+        self.encoder_depth = encoder_depth
+        self.encoder_heads = encoder_heads
+        self.encoder_width = encoder_width
+        self.encoder_intermediate_dim = encoder_intermediate_dim
+
+        self.unimodal_decoder_depth = unimodal_decoder_depth
+        self.multimodal_decoder_depth = multimodal_decoder_depth
+        self.decoder_intermediate_dim = decoder_intermediate_dim
+        self.unimodal_decoder_heads = unimodal_decoder_heads
+        self.multimodal_decoder_heads = multimodal_decoder_heads
+
+        self.contrastive_query_length = contrastive_query_length
+        self.contrastive_attn_heads = contrastive_attn_heads
+        self.contrastive_loss_weight = contrastive_loss_weight
+
+        self.captioning_query_length = captioning_query_length
+        self.captioning_attn_heads = captioning_attn_heads
+        self.captioning_loss_weight = captioning_loss_weight
+
+        #
+        # Layer Definitions
+        #
+        self.image_patching = PatchingAndEmbedding(
+            self.encoder_width, self.img_patch_size
+        )
+        self.image_encoder = Sequential(
+            [
+                CVTransformerEncoder(
+                    self.encoder_width,
+                    self.encoder_heads,
+                    self.encoder_intermediate_dim,
+                )
+                for _ in range(self.encoder_depth)
+            ]
+        )
+
+        self.text_embedding = RotaryEmbedding()
+        self.unimodal_text_decoders = [
+            TransformerDecoder(
+                self.decoder_intermediate_dim, self.unimodal_decoder_heads
+            )
+            for _ in range(self.unimodal_decoder_depth)
+        ]
+        self.multimodal_text_decoders = [
+            TransformerDecoder(
+                self.decoder_intermediate_dim, self.multimodal_decoder_heads
+            )
+            for _ in range(self.multimodal_decoder_depth)
+        ]
+
+        self.contrastive_attn_pooling = CoCaAttentionPooling(
+            self.encoder_width, self.contrastive_attn_heads
+        )
+        self.captioning_attn_pooling = CoCaAttentionPooling(
+            self.encoder_width, self.captioning_attn_heads
+        )
+
+        # Learnable query tokens, created below via add_weight
+        self.contrastive_query = None
+        self.captioning_query = None
+
+        #
+        # Functional Model
+        #
+        images = keras.Input(shape=self.img_shape, name="images")
+        captions = keras.Input(shape=self.caption_shape, name="captions")
+
+        img_encoding = self.image_patching(
+            images
+        )  # [batch_size, img_patches_len+1, encoder_width]
+        img_encoding = self.image_encoder(
+            img_encoding
+        )  # [batch_size, img_patches_len+1, encoder_width]
+
+        # Learnable weights; the leading dimension of 1 is shared across the batch
+        self.contrastive_query = self.add_weight(
+            shape=(
+                1,
+                self.contrastive_query_length,
+                self.encoder_width,
+            ),
+            name="contrastive_query",
+            trainable=True,
+        )
+        self.captioning_query = self.add_weight(
+            shape=(
+                1,
+                self.captioning_query_length,
+                self.encoder_width,
+            ),
+            name="captioning_query",
+            trainable=True,
+        )
+
+        # Used for the contrastive loss; [batch_size, contrastive_query_length, encoder_width]
+        contrastive_feature = self.contrastive_attn_pooling(
+            self.contrastive_query, img_encoding
+        )
+
+        # [batch_size, captioning_query_length, encoder_width]
+        captioning_feature = self.captioning_attn_pooling(
+            self.captioning_query, img_encoding
+        )
+
+        # Learnable CLS token
+        self.cls_token = self.add_weight(
+            shape=(1, 1, self.caption_shape[-1]), name="cls_token", trainable=True
+        )
+
+        # [batch_size, sequence_length+1, text_dim]
+        text_tokens = ops.concatenate([captions, self.cls_token], axis=1)
+        mask = ops.concatenate(
+            [ops.ones_like(captions), ops.zeros_like(self.cls_token)], axis=1
+        )
+
+        # [batch_size, sequence_length+1, text_dim]
+        embed_text = self.text_embedding(text_tokens)
+        unimodal_out = embed_text
+        for decoder in self.unimodal_text_decoders:
+            unimodal_out = decoder(unimodal_out, decoder_attention_mask=mask)
+
+        # [batch_size, sequence_length, text_dim]; note that the CLS token is removed
+        multimodal_out = unimodal_out[:, :-1, :]
+        for decoder in self.multimodal_text_decoders:
+            multimodal_out = decoder(
+                multimodal_out,
+                encoder_sequence=captioning_feature,
+                decoder_attention_mask=mask,
+            )
+
+        super().__init__(
+            inputs={
+                "images": images,
+                "captions": captions,
+            },
+            outputs={
+                "multimodal_out": multimodal_out,
+                "contrastive_feature": contrastive_feature,
+            },
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "img_shape": self.img_shape,
+                "caption_shape": self.caption_shape,
+                "img_patch_size": self.img_patch_size,
+                "encoder_depth": self.encoder_depth,
+                "encoder_heads": self.encoder_heads,
+                "encoder_width": self.encoder_width,
+                "encoder_intermediate_dim": self.encoder_intermediate_dim,
+                "unimodal_decoder_depth": self.unimodal_decoder_depth,
+                "multimodal_decoder_depth": self.multimodal_decoder_depth,
+                "decoder_intermediate_dim": self.decoder_intermediate_dim,
+                "unimodal_decoder_heads": self.unimodal_decoder_heads,
+                "multimodal_decoder_heads": self.multimodal_decoder_heads,
+                "contrastive_query_length": self.contrastive_query_length,
+                "contrastive_attn_heads": self.contrastive_attn_heads,
+                "contrastive_loss_weight": self.contrastive_loss_weight,
+                "captioning_query_length": self.captioning_query_length,
+                "captioning_attn_heads": self.captioning_attn_heads,
"captioning_loss_weight": self.captioning_loss_weight, + } + ) + return config + + @classmethod + def from_config(cls, config): + return cls(**config) + + def load_own_variables(self, store): + print(store) + super().load_own_variables(store) \ No newline at end of file diff --git a/keras_cv/models/feature_extractor/coca/coca_model_test.py b/keras_cv/models/feature_extractor/coca/coca_model_test.py new file mode 100644 index 0000000000..f9c99f903e --- /dev/null +++ b/keras_cv/models/feature_extractor/coca/coca_model_test.py @@ -0,0 +1,23 @@ +import keras.saving +import numpy as np +import pytest +import os + +from keras_cv.models.feature_extractor.coca import CoCa +from keras_cv.tests.test_case import TestCase + +class CoCaTest(TestCase): + + @pytest.mark.large + def test_coca_model_save(self): + # TODO: Transformer encoder breaks if you have project dim < num heads + model = CoCa() + + save_path = os.path.join(self.get_temp_dir(), "coca.keras") + model.save(save_path) + + restored_model = keras.models.load_model(save_path, custom_objects={"CoCa": CoCa}) + + self.assertIsInstance(restored_model, CoCa) + +