From e603967f82d0d5c204ed00dfd6643226cc148305 Mon Sep 17 00:00:00 2001 From: sanskarmodi8 Date: Thu, 11 Sep 2025 18:57:08 +0530 Subject: [PATCH 1/3] added backbone and its test for voxtral model --- keras_hub/api/models/__init__.py | 3 + .../src/models/voxtral/voxtral_backbone.py | 207 ++++++++++++++++++ .../models/voxtral/voxtral_backbone_test.py | 44 ++++ 3 files changed, 254 insertions(+) create mode 100644 keras_hub/src/models/voxtral/voxtral_backbone.py create mode 100644 keras_hub/src/models/voxtral/voxtral_backbone_test.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index fe220e2d43..a438b047fa 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -677,6 +677,9 @@ from keras_hub.src.models.vit_det.vit_det_backbone import ( ViTDetBackbone as ViTDetBackbone, ) +from keras_hub.src.models.voxtral.voxtral_backbone import ( + VoxTralBackbone as VoxTralBackbone, +) from keras_hub.src.models.whisper.whisper_backbone import ( WhisperBackbone as WhisperBackbone, ) diff --git a/keras_hub/src/models/voxtral/voxtral_backbone.py b/keras_hub/src/models/voxtral/voxtral_backbone.py new file mode 100644 index 0000000000..0e47e3df37 --- /dev/null +++ b/keras_hub/src/models/voxtral/voxtral_backbone.py @@ -0,0 +1,207 @@ +import tensorflow as tf +from keras import initializers +from keras import layers +from keras import mixed_precision + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder +from keras_hub.src.models.backbone import Backbone + + +def voxtral_kernel_initializer(stddev=0.02): + """Initializer for VoxTral layers (TruncatedNormal).""" + return initializers.TruncatedNormal(stddev=stddev) + + +class ChunkAndPad(layers.Layer): + """Pads and splits spectrogram into fixed-length chunks.""" + + def __init__(self, frames_per_chunk, **kwargs): + super().__init__(**kwargs) + self.frames_per_chunk = int(frames_per_chunk) + + def call(self, x): + B, T = tf.shape(x)[0], tf.shape(x)[1] + pad_len = (-T) % self.frames_per_chunk + x = tf.pad(x, [[0, 0], [0, pad_len], [0, 0]]) + n_chunks = tf.math.floordiv(T + pad_len, self.frames_per_chunk) + return tf.reshape( + x, [B * n_chunks, self.frames_per_chunk, tf.shape(x)[2]] + ) + + +class PositionalEmbedding(layers.Layer): + """Learnable positional embedding per chunk.""" + + def __init__(self, length, dim, **kwargs): + super().__init__(**kwargs) + self.length = int(length) + self.dim = int(dim) + + def build(self, input_shape): + self.pos_emb = self.add_weight( + name="pos_emb", + shape=(self.length, self.dim), + initializer=initializers.RandomNormal(stddev=0.02), + trainable=True, + ) + super().build(input_shape) + + def call(self, x): + return x + self.pos_emb[None, :, :] + + +class ReassembleChunks(layers.Layer): + """Reassembles chunked outputs back into (B, T, H).""" + + def __init__(self, frames_per_chunk, postproc_chunk_len=None, **kwargs): + super().__init__(**kwargs) + self.frames_per_chunk = int(frames_per_chunk) + self.postproc_chunk_len = postproc_chunk_len + + def call(self, processed_chunks, orig_spectrogram): + B, T = tf.shape(orig_spectrogram)[0], tf.shape(orig_spectrogram)[1] + n_chunks = tf.cast( + tf.math.floordiv( + T + self.frames_per_chunk - 1, self.frames_per_chunk + ), + tf.int32, + ) + T_chunk, H = ( + tf.shape(processed_chunks)[1], + tf.shape(processed_chunks)[2], + ) + return tf.reshape(processed_chunks, [B, n_chunks * T_chunk, H]) + + 
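# Illustrative sketch (assuming the small configuration used in the unit test
# below: max_chunk_seconds=1, sr=16000, hop_length=160, adapter_downsample=2)
# of the chunking arithmetic that ChunkAndPad and ReassembleChunks encode.
frames_per_chunk = 1 * (16000 // 160)          # 100 mel frames per 1 s chunk
num_frames = 2542                              # frames in the test spectrogram
n_chunks = -(-num_frames // frames_per_chunk)  # ceil division -> 26 chunks
post_conv = frames_per_chunk // 2              # stride-2 conv stem -> 50
post_pool = post_conv // 2                     # adapter pooling by 2 -> 25
print(n_chunks * post_pool)                    # 650 frames, matching the test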
+@keras_hub_export("keras_hub.models.VoxTralBackbone") +class VoxTralBackbone(Backbone): + """VoxTral audio encoder + adapter backbone.""" + + def __init__( + self, + num_layers=32, + num_heads=20, + hidden_dim=1280, + intermediate_dim=5120, + adapter_downsample=4, + dropout=0.1, + max_chunk_seconds=30, + sr=16000, + hop_length=160, + dtype="float32", + **kwargs, + ): + self.num_layers = int(num_layers) + self.num_heads = int(num_heads) + self.hidden_dim = int(hidden_dim) + self.intermediate_dim = int(intermediate_dim) + self.adapter_downsample = int(adapter_downsample) + self.dropout = float(dropout) + self.max_chunk_seconds = int(max_chunk_seconds) + self.sr = int(sr) + self.hop_length = int(hop_length) + + # Frames per chunk before conv + self.frames_per_chunk_preconv = int( + self.max_chunk_seconds * (self.sr / self.hop_length) + ) + self.postconv_frames_per_chunk = self.frames_per_chunk_preconv // 2 + + # Determine layer dtype for mixed precision + if isinstance(dtype, mixed_precision.Policy): + self.layer_dtype = dtype.compute_dtype + else: + self.layer_dtype = dtype + + # Conv1D stem + self.conv_stem_1 = layers.Conv1D( + filters=self.hidden_dim, + kernel_size=3, + strides=2, + padding="same", + activation="relu", + kernel_initializer=voxtral_kernel_initializer(), + dtype=self.layer_dtype, + name="conv_stem_1", + ) + self.conv_stem_2 = layers.Conv1D( + filters=self.hidden_dim, + kernel_size=3, + strides=1, + padding="same", + activation="relu", + kernel_initializer=voxtral_kernel_initializer(), + dtype=self.layer_dtype, + name="conv_stem_2", + ) + + # Transformer layers + self.transformer_layers = [ + TransformerEncoder( + num_heads=self.num_heads, + intermediate_dim=self.intermediate_dim, + dropout=self.dropout, + name=f"transformer_layer_{i}", + ) + for i in range(self.num_layers) + ] + + # Adapter + self.adapter_dense = layers.Dense( + self.hidden_dim, + activation="relu", + kernel_initializer=voxtral_kernel_initializer(), + dtype=self.layer_dtype, + name="adapter_dense", + ) + self.adapter_pool = layers.AveragePooling1D( + pool_size=self.adapter_downsample, + strides=self.adapter_downsample, + padding="valid", + name="adapter_downsample", + ) + + # Positional embeddings + self.pos_emb = PositionalEmbedding( + self.postconv_frames_per_chunk, self.hidden_dim, name="pos_emb" + ) + + # Functional model + spectrogram_input = tf.keras.Input( + shape=(None, 128), dtype=self.layer_dtype, name="spectrogram" + ) + x = ChunkAndPad(self.frames_per_chunk_preconv, name="chunk_and_pad")( + spectrogram_input + ) + x = self.conv_stem_1(x) + x = self.conv_stem_2(x) + x = self.pos_emb(x) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x) + x = self.adapter_dense(x) + x = self.adapter_pool(x) + outputs = ReassembleChunks( + self.frames_per_chunk_preconv, name="reassemble_chunks" + )(x, spectrogram_input) + + super().__init__( + inputs=spectrogram_input, outputs=outputs, dtype=dtype, **kwargs + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "adapter_downsample": self.adapter_downsample, + "dropout": self.dropout, + "max_chunk_seconds": self.max_chunk_seconds, + "sr": self.sr, + "hop_length": self.hop_length, + } + ) + return config diff --git a/keras_hub/src/models/voxtral/voxtral_backbone_test.py b/keras_hub/src/models/voxtral/voxtral_backbone_test.py new file mode 100644 index 
0000000000..6186ff3275 --- /dev/null +++ b/keras_hub/src/models/voxtral/voxtral_backbone_test.py @@ -0,0 +1,44 @@ +import pytest +from keras import mixed_precision +from keras import ops + +from keras_hub.src.models.voxtral.voxtral_backbone import VoxTralBackbone +from keras_hub.src.tests.test_case import TestCase + + +class VoxTralBackboneTest(TestCase): + """Unit tests for VoxTralBackbone.""" + + def setUp(self): + """Initialize default backbone arguments and input data.""" + self.init_kwargs = { + "num_layers": 2, + "num_heads": 2, + "hidden_dim": 16, + "intermediate_dim": 32, + "adapter_downsample": 2, + "dropout": 0.0, + "max_chunk_seconds": 1, + "sr": 16000, + "hop_length": 160, + "dtype": "float32", + } + # Dummy input: shape (batch, time, features) + self.input_data = ops.ones((1, 2542, 128), dtype="float32") + + def test_backbone_basics(self): + """Test forward pass and output shape with float32.""" + mixed_precision.set_global_policy("float32") + model = VoxTralBackbone(**self.init_kwargs) + output = model(self.input_data) + assert tuple(output.shape) == (1, 650, 16) + assert output.dtype.name == "float32" + + @pytest.mark.large + def test_saved_model(self): + """Test saving and loading the model.""" + self.run_model_saving_test( + cls=VoxTralBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) From 4572cef469690068370008fd15cd014c4b8687e9 Mon Sep 17 00:00:00 2001 From: sanskarmodi8 Date: Fri, 12 Sep 2025 14:47:20 +0530 Subject: [PATCH 2/3] resolved some issues in the code as reviewed by gemini-bot --- .../src/models/voxtral/voxtral_backbone.py | 135 +++++++++++++----- .../models/voxtral/voxtral_backbone_test.py | 11 +- 2 files changed, 104 insertions(+), 42 deletions(-) diff --git a/keras_hub/src/models/voxtral/voxtral_backbone.py b/keras_hub/src/models/voxtral/voxtral_backbone.py index 0e47e3df37..ee80622882 100644 --- a/keras_hub/src/models/voxtral/voxtral_backbone.py +++ b/keras_hub/src/models/voxtral/voxtral_backbone.py @@ -1,7 +1,7 @@ -import tensorflow as tf +from keras import Input from keras import initializers from keras import layers -from keras import mixed_precision +from keras import ops from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder @@ -14,24 +14,38 @@ def voxtral_kernel_initializer(stddev=0.02): class ChunkAndPad(layers.Layer): - """Pads and splits spectrogram into fixed-length chunks.""" + """Pads and splits spectrogram into fixed-length chunks. + + Args: + frames_per_chunk: int. Number of frames per chunk. + """ def __init__(self, frames_per_chunk, **kwargs): super().__init__(**kwargs) self.frames_per_chunk = int(frames_per_chunk) def call(self, x): - B, T = tf.shape(x)[0], tf.shape(x)[1] + B, T = ops.shape(x)[0], ops.shape(x)[1] pad_len = (-T) % self.frames_per_chunk - x = tf.pad(x, [[0, 0], [0, pad_len], [0, 0]]) - n_chunks = tf.math.floordiv(T + pad_len, self.frames_per_chunk) - return tf.reshape( - x, [B * n_chunks, self.frames_per_chunk, tf.shape(x)[2]] + x = ops.pad(x, [[0, 0], [0, pad_len], [0, 0]]) + n_chunks = ops.floor_divide(T + pad_len, self.frames_per_chunk) + return ops.reshape( + x, [B * n_chunks, self.frames_per_chunk, ops.shape(x)[2]] ) + def get_config(self): + config = super().get_config() + config.update({"frames_per_chunk": self.frames_per_chunk}) + return config + class PositionalEmbedding(layers.Layer): - """Learnable positional embedding per chunk.""" + """Learnable positional embedding per chunk. + + Args: + length: int. 
Sequence length. + dim: int. Embedding dimension. + """ def __init__(self, length, dim, **kwargs): super().__init__(**kwargs) @@ -44,39 +58,86 @@ def build(self, input_shape): shape=(self.length, self.dim), initializer=initializers.RandomNormal(stddev=0.02), trainable=True, + dtype=self.compute_dtype, ) super().build(input_shape) def call(self, x): - return x + self.pos_emb[None, :, :] + # Cast embedding to input dtype to avoid float16/32 mismatch + return x + ops.cast(self.pos_emb[None, :, :], x.dtype) + + def get_config(self): + config = super().get_config() + config.update({"length": self.length, "dim": self.dim}) + return config class ReassembleChunks(layers.Layer): - """Reassembles chunked outputs back into (B, T, H).""" + """Reassembles chunked outputs back into (B, T, H). + + Args: + frames_per_chunk: int. Frames per chunk pre-conv. + postproc_chunk_len: Optional[int]. Post-processing chunk length. + """ def __init__(self, frames_per_chunk, postproc_chunk_len=None, **kwargs): super().__init__(**kwargs) self.frames_per_chunk = int(frames_per_chunk) - self.postproc_chunk_len = postproc_chunk_len + self.postproc_chunk_len = ( + None if postproc_chunk_len is None else int(postproc_chunk_len) + ) def call(self, processed_chunks, orig_spectrogram): - B, T = tf.shape(orig_spectrogram)[0], tf.shape(orig_spectrogram)[1] - n_chunks = tf.cast( - tf.math.floordiv( + B, T = ops.shape(orig_spectrogram)[0], ops.shape(orig_spectrogram)[1] + n_chunks = ops.cast( + ops.floor_divide( T + self.frames_per_chunk - 1, self.frames_per_chunk ), - tf.int32, + "int32", ) T_chunk, H = ( - tf.shape(processed_chunks)[1], - tf.shape(processed_chunks)[2], + ops.shape(processed_chunks)[1], + ops.shape(processed_chunks)[2], + ) + return ops.reshape(processed_chunks, [B, n_chunks * T_chunk, H]) + + def get_config(self): + config = super().get_config() + config.update( + { + "frames_per_chunk": self.frames_per_chunk, + "postproc_chunk_len": self.postproc_chunk_len, + } ) - return tf.reshape(processed_chunks, [B, n_chunks * T_chunk, H]) + return config @keras_hub_export("keras_hub.models.VoxTralBackbone") class VoxTralBackbone(Backbone): - """VoxTral audio encoder + adapter backbone.""" + """VoxTral audio encoder + adapter backbone. + + This model implements the encoder portion of the VoxTral model. It takes + a log-Mel spectrogram and produces a sequence of hidden states. + + Args: + num_layers: int, number of transformer layers. + num_heads: int, number of attention heads. + hidden_dim: int, embedding size. + intermediate_dim: int, size of feedforward network hidden layer. + adapter_downsample: int, pooling factor after adapter dense. + dropout: float, dropout probability. + max_chunk_seconds: int, length of chunking in seconds. + sr: int, sample rate. + hop_length: int, hop length for spectrogram frames. + dtype: str or mixed_precision.Policy, dtype for layers. 
+ + Example: + ```python + from keras_hub.models import VoxTralBackbone + model = VoxTralBackbone() + output = model(input_tensor) + ``` + """ def __init__( self, @@ -89,7 +150,7 @@ def __init__( max_chunk_seconds=30, sr=16000, hop_length=160, - dtype="float32", + dtype=None, **kwargs, ): self.num_layers = int(num_layers) @@ -108,12 +169,6 @@ def __init__( ) self.postconv_frames_per_chunk = self.frames_per_chunk_preconv // 2 - # Determine layer dtype for mixed precision - if isinstance(dtype, mixed_precision.Policy): - self.layer_dtype = dtype.compute_dtype - else: - self.layer_dtype = dtype - # Conv1D stem self.conv_stem_1 = layers.Conv1D( filters=self.hidden_dim, @@ -122,7 +177,7 @@ def __init__( padding="same", activation="relu", kernel_initializer=voxtral_kernel_initializer(), - dtype=self.layer_dtype, + dtype=dtype, name="conv_stem_1", ) self.conv_stem_2 = layers.Conv1D( @@ -132,7 +187,7 @@ def __init__( padding="same", activation="relu", kernel_initializer=voxtral_kernel_initializer(), - dtype=self.layer_dtype, + dtype=dtype, name="conv_stem_2", ) @@ -143,6 +198,7 @@ def __init__( intermediate_dim=self.intermediate_dim, dropout=self.dropout, name=f"transformer_layer_{i}", + dtype=dtype, ) for i in range(self.num_layers) ] @@ -152,7 +208,7 @@ def __init__( self.hidden_dim, activation="relu", kernel_initializer=voxtral_kernel_initializer(), - dtype=self.layer_dtype, + dtype=dtype, name="adapter_dense", ) self.adapter_pool = layers.AveragePooling1D( @@ -160,20 +216,25 @@ def __init__( strides=self.adapter_downsample, padding="valid", name="adapter_downsample", + dtype=dtype, ) # Positional embeddings self.pos_emb = PositionalEmbedding( - self.postconv_frames_per_chunk, self.hidden_dim, name="pos_emb" + self.postconv_frames_per_chunk, + self.hidden_dim, + name="pos_emb", + dtype=dtype, ) # Functional model - spectrogram_input = tf.keras.Input( - shape=(None, 128), dtype=self.layer_dtype, name="spectrogram" - ) - x = ChunkAndPad(self.frames_per_chunk_preconv, name="chunk_and_pad")( - spectrogram_input + spectrogram_input = Input( + shape=(None, 128), dtype="float32", name="spectrogram" ) + + x = ChunkAndPad( + self.frames_per_chunk_preconv, name="chunk_and_pad", dtype="float32" + )(spectrogram_input) x = self.conv_stem_1(x) x = self.conv_stem_2(x) x = self.pos_emb(x) @@ -182,7 +243,7 @@ def __init__( x = self.adapter_dense(x) x = self.adapter_pool(x) outputs = ReassembleChunks( - self.frames_per_chunk_preconv, name="reassemble_chunks" + self.frames_per_chunk_preconv, name="reassemble_chunks", dtype=dtype )(x, spectrogram_input) super().__init__( diff --git a/keras_hub/src/models/voxtral/voxtral_backbone_test.py b/keras_hub/src/models/voxtral/voxtral_backbone_test.py index 6186ff3275..03f792f69e 100644 --- a/keras_hub/src/models/voxtral/voxtral_backbone_test.py +++ b/keras_hub/src/models/voxtral/voxtral_backbone_test.py @@ -21,7 +21,6 @@ def setUp(self): "max_chunk_seconds": 1, "sr": 16000, "hop_length": 160, - "dtype": "float32", } # Dummy input: shape (batch, time, features) self.input_data = ops.ones((1, 2542, 128), dtype="float32") @@ -29,10 +28,12 @@ def setUp(self): def test_backbone_basics(self): """Test forward pass and output shape with float32.""" mixed_precision.set_global_policy("float32") - model = VoxTralBackbone(**self.init_kwargs) - output = model(self.input_data) - assert tuple(output.shape) == (1, 650, 16) - assert output.dtype.name == "float32" + self.run_backbone_test( + cls=VoxTralBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + 
expected_output_shape=(1, 650, 16), + ) @pytest.mark.large def test_saved_model(self): From 19448c1acd5f67a0d8ef82ab09fd00f33820031f Mon Sep 17 00:00:00 2001 From: sanskarmodi8 Date: Fri, 12 Sep 2025 16:27:51 +0530 Subject: [PATCH 3/3] updated docstrings and fixed some issues --- .../src/models/voxtral/voxtral_backbone.py | 176 +++++++++++++----- 1 file changed, 130 insertions(+), 46 deletions(-) diff --git a/keras_hub/src/models/voxtral/voxtral_backbone.py b/keras_hub/src/models/voxtral/voxtral_backbone.py index ee80622882..4dab3eff5d 100644 --- a/keras_hub/src/models/voxtral/voxtral_backbone.py +++ b/keras_hub/src/models/voxtral/voxtral_backbone.py @@ -1,6 +1,7 @@ from keras import Input from keras import initializers from keras import layers +from keras import mixed_precision from keras import ops from keras_hub.src.api_export import keras_hub_export @@ -9,15 +10,28 @@ def voxtral_kernel_initializer(stddev=0.02): - """Initializer for VoxTral layers (TruncatedNormal).""" + """ + Create a TruncatedNormal initializer for VoxTral layers. + + Args: + stddev (float): Standard deviation of the truncated normal distribution. + + Returns: + keras.initializers.Initializer: Truncated normal initializer. + """ return initializers.TruncatedNormal(stddev=stddev) class ChunkAndPad(layers.Layer): - """Pads and splits spectrogram into fixed-length chunks. + """ + Pads and splits an input spectrogram into fixed-length chunks. + + This layer ensures the time axis of the input is divisible by + `frames_per_chunk` by padding zeros, then reshapes into + `(batch * n_chunks, frames_per_chunk, features)`. Args: - frames_per_chunk: int. Number of frames per chunk. + frames_per_chunk (int): Number of frames per chunk. """ def __init__(self, frames_per_chunk, **kwargs): @@ -25,6 +39,7 @@ def __init__(self, frames_per_chunk, **kwargs): self.frames_per_chunk = int(frames_per_chunk) def call(self, x): + """Pad and chunk the input tensor along time dimension.""" B, T = ops.shape(x)[0], ops.shape(x)[1] pad_len = (-T) % self.frames_per_chunk x = ops.pad(x, [[0, 0], [0, pad_len], [0, 0]]) @@ -33,6 +48,29 @@ def call(self, x): x, [B * n_chunks, self.frames_per_chunk, ops.shape(x)[2]] ) + def compute_output_shape(self, input_shape): + """ + Compute static output shape for Keras/JAX backends. + + Args: + input_shape (tuple): (batch, time, features). + + Returns: + tuple: (batch * n_chunks, frames_per_chunk, features) + """ + batch, time, feat = input_shape + if time is None: + n_chunks = None + else: + import math + + n_chunks = math.ceil(time / self.frames_per_chunk) + return ( + None if batch is None else batch * n_chunks, + self.frames_per_chunk, + feat, + ) + def get_config(self): config = super().get_config() config.update({"frames_per_chunk": self.frames_per_chunk}) @@ -40,11 +78,13 @@ def get_config(self): class PositionalEmbedding(layers.Layer): - """Learnable positional embedding per chunk. + """ + Learnable positional embedding added to each time step in a chunk. Args: - length: int. Sequence length. - dim: int. Embedding dimension. + length (int): Sequence length of each chunk + (frames per chunk post-conv). + dim (int): Embedding dimension. 
""" def __init__(self, length, dim, **kwargs): @@ -53,6 +93,7 @@ def __init__(self, length, dim, **kwargs): self.dim = int(dim) def build(self, input_shape): + """Create the embedding weights.""" self.pos_emb = self.add_weight( name="pos_emb", shape=(self.length, self.dim), @@ -63,9 +104,21 @@ def build(self, input_shape): super().build(input_shape) def call(self, x): - # Cast embedding to input dtype to avoid float16/32 mismatch + """ + Add the positional embedding to the input tensor. + + Args: + x (Tensor): Input tensor of shape (batch, time, dim). + + Returns: + Tensor: Input tensor with positional embedding added. + """ return x + ops.cast(self.pos_emb[None, :, :], x.dtype) + def compute_output_shape(self, input_shape): + """Return same shape as input.""" + return input_shape + def get_config(self): config = super().get_config() config.update({"length": self.length, "dim": self.dim}) @@ -73,11 +126,12 @@ def get_config(self): class ReassembleChunks(layers.Layer): - """Reassembles chunked outputs back into (B, T, H). + """ + Reassembles chunked outputs back into `(batch, time, hidden_dim)`. Args: - frames_per_chunk: int. Frames per chunk pre-conv. - postproc_chunk_len: Optional[int]. Post-processing chunk length. + frames_per_chunk (int): Frames per chunk pre-conv. + postproc_chunk_len (int, optional): Chunk length after processing. """ def __init__(self, frames_per_chunk, postproc_chunk_len=None, **kwargs): @@ -88,6 +142,18 @@ def __init__(self, frames_per_chunk, postproc_chunk_len=None, **kwargs): ) def call(self, processed_chunks, orig_spectrogram): + """ + Reassemble processed chunks into a continuous time sequence. + + Args: + processed_chunks (Tensor): Output of transformer layers + of shape (B*n_chunks, T_chunk, H). + orig_spectrogram (Tensor): Original input spectrogram + of shape (B, T, F). + + Returns: + Tensor: Reassembled tensor of shape (B, T', H). + """ B, T = ops.shape(orig_spectrogram)[0], ops.shape(orig_spectrogram)[1] n_chunks = ops.cast( ops.floor_divide( @@ -101,6 +167,10 @@ def call(self, processed_chunks, orig_spectrogram): ) return ops.reshape(processed_chunks, [B, n_chunks * T_chunk, H]) + def compute_output_shape(self, input_shape): + """Return shape compatible with a single long sequence.""" + return input_shape + def get_config(self): config = super().get_config() config.update( @@ -114,29 +184,23 @@ def get_config(self): @keras_hub_export("keras_hub.models.VoxTralBackbone") class VoxTralBackbone(Backbone): - """VoxTral audio encoder + adapter backbone. + """ + VoxTral audio encoder + adapter backbone. - This model implements the encoder portion of the VoxTral model. It takes - a log-Mel spectrogram and produces a sequence of hidden states. + This model implements the encoder portion of the VoxTral model. + It takes a log-Mel spectrogram and produces a sequence of hidden states. Args: - num_layers: int, number of transformer layers. - num_heads: int, number of attention heads. - hidden_dim: int, embedding size. - intermediate_dim: int, size of feedforward network hidden layer. - adapter_downsample: int, pooling factor after adapter dense. - dropout: float, dropout probability. - max_chunk_seconds: int, length of chunking in seconds. - sr: int, sample rate. - hop_length: int, hop length for spectrogram frames. - dtype: str or mixed_precision.Policy, dtype for layers. - - Example: - ```python - from keras_hub.models import VoxTralBackbone - model = VoxTralBackbone() - output = model(input_tensor) - ``` + num_layers (int): Number of transformer layers. 
+ num_heads (int): Number of attention heads. + hidden_dim (int): Embedding size. + intermediate_dim (int): Size of feedforward network hidden layer. + adapter_downsample (int): Pooling factor after adapter dense. + dropout (float): Dropout probability. + max_chunk_seconds (int): Chunking length in seconds. + sr (int): Audio sample rate. + hop_length (int): Hop length for spectrogram frames. + dtype (str or mixed_precision.Policy, optional): Layer dtype. """ def __init__( @@ -153,6 +217,8 @@ def __init__( dtype=None, **kwargs, ): + """Initialize the VoxTral backbone.""" + # Store configuration self.num_layers = int(num_layers) self.num_heads = int(num_heads) self.hidden_dim = int(hidden_dim) @@ -169,7 +235,21 @@ def __init__( ) self.postconv_frames_per_chunk = self.frames_per_chunk_preconv // 2 - # Conv1D stem + # --- Mixed precision policy --- + if dtype is None: + policy = mixed_precision.global_policy() + elif isinstance(dtype, str): + policy = mixed_precision.Policy(dtype) + elif isinstance(dtype, dict): # coming from config + policy = mixed_precision.Policy(dtype["config"]["name"]) + else: + policy = dtype # already a Policy + + variable_dtype = policy.variable_dtype + compute_dtype = policy.compute_dtype + self._policy = policy # save for get_config() + + # --- Layers --- self.conv_stem_1 = layers.Conv1D( filters=self.hidden_dim, kernel_size=3, @@ -177,7 +257,7 @@ def __init__( padding="same", activation="relu", kernel_initializer=voxtral_kernel_initializer(), - dtype=dtype, + dtype=variable_dtype, name="conv_stem_1", ) self.conv_stem_2 = layers.Conv1D( @@ -187,28 +267,26 @@ def __init__( padding="same", activation="relu", kernel_initializer=voxtral_kernel_initializer(), - dtype=dtype, + dtype=variable_dtype, name="conv_stem_2", ) - # Transformer layers self.transformer_layers = [ TransformerEncoder( num_heads=self.num_heads, intermediate_dim=self.intermediate_dim, dropout=self.dropout, name=f"transformer_layer_{i}", - dtype=dtype, + dtype=variable_dtype, ) for i in range(self.num_layers) ] - # Adapter self.adapter_dense = layers.Dense( self.hidden_dim, activation="relu", kernel_initializer=voxtral_kernel_initializer(), - dtype=dtype, + dtype=variable_dtype, name="adapter_dense", ) self.adapter_pool = layers.AveragePooling1D( @@ -216,24 +294,24 @@ def __init__( strides=self.adapter_downsample, padding="valid", name="adapter_downsample", - dtype=dtype, + dtype=variable_dtype, ) - # Positional embeddings self.pos_emb = PositionalEmbedding( self.postconv_frames_per_chunk, self.hidden_dim, name="pos_emb", - dtype=dtype, + dtype=variable_dtype, ) - # Functional model + # --- Functional graph --- spectrogram_input = Input( - shape=(None, 128), dtype="float32", name="spectrogram" + shape=(None, 128), dtype=compute_dtype, name="spectrogram" ) - x = ChunkAndPad( - self.frames_per_chunk_preconv, name="chunk_and_pad", dtype="float32" + self.frames_per_chunk_preconv, + name="chunk_and_pad", + dtype=compute_dtype, )(spectrogram_input) x = self.conv_stem_1(x) x = self.conv_stem_2(x) @@ -243,11 +321,16 @@ def __init__( x = self.adapter_dense(x) x = self.adapter_pool(x) outputs = ReassembleChunks( - self.frames_per_chunk_preconv, name="reassemble_chunks", dtype=dtype + self.frames_per_chunk_preconv, + name="reassemble_chunks", + dtype=compute_dtype, )(x, spectrogram_input) super().__init__( - inputs=spectrogram_input, outputs=outputs, dtype=dtype, **kwargs + inputs=spectrogram_input, + outputs=outputs, + dtype=compute_dtype, + **kwargs, ) def get_config(self): @@ -263,6 +346,7 @@ def 
get_config(self): "max_chunk_seconds": self.max_chunk_seconds, "sr": self.sr, "hop_length": self.hop_length, + "dtype": self._policy.name, # store string } ) return config
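Taken together, the three patches give a backbone that can be exercised end to end. The snippet below is a minimal usage sketch, assuming the series has been applied to a keras_hub checkout; it reuses the small configuration from the unit tests rather than the 32-layer defaults, so the output shape matches the test expectation of (1, 650, 16).

```python
import numpy as np

from keras_hub.src.models.voxtral.voxtral_backbone import VoxTralBackbone

# Small configuration mirroring voxtral_backbone_test.py.
backbone = VoxTralBackbone(
    num_layers=2,
    num_heads=2,
    hidden_dim=16,
    intermediate_dim=32,
    adapter_downsample=2,
    dropout=0.0,
    max_chunk_seconds=1,
    sr=16000,
    hop_length=160,
)

# Dummy log-Mel spectrogram: (batch, time frames, 128 mel bins).
spectrogram = np.ones((1, 2542, 128), dtype="float32")
hidden_states = backbone(spectrogram)
print(hidden_states.shape)  # (1, 650, 16): 26 chunks x 25 pooled frames each
```

The same call path applies with the constructor defaults (num_layers=32, hidden_dim=1280, max_chunk_seconds=30, adapter_downsample=4), in which case each 30 s chunk contributes 375 adapter frames of width 1280. The public alias keras_hub.models.VoxTralBackbone registered in patch 1 should be an equivalent import path.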