
Enabling structured inputs to call for Keras 3 #18735

Open
@areiner222

Description


I've relied heavily on passing structured inputs to `call` in subclassed `Model` and `Layer` classes. Will Keras 3 support this?

I can't seem to pass a TensorFlow `ExtensionType` or a generic dataclass (a `PyTreeNode` in JAX); both hit this value check.

I believe passing this kind of structured input should be possible, especially given TensorFlow's `__tf_flatten__` / `__tf_unflatten__` protocol and JAX's pytree registration.
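The flatten/unflatten idea can be illustrated without TensorFlow or JAX. The sketch below is a minimal, self-contained example under stated assumptions: `register_node`, `tree_flatten`, `tree_unflatten`, `CompositeInput`, and `check_leaves` are hypothetical helpers loosely modeled on JAX pytree registration, not Keras, TensorFlow, or JAX APIs. Static config (`meta`) travels in the tree definition while dynamic values (`value`) become leaves, so an input check could validate the leaves instead of rejecting the container.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical registry: type -> (flatten_fn, unflatten_fn).
# Loosely modeled on JAX pytree registration; not a Keras, TF, or JAX API.
_REGISTRY = {}


def register_node(cls, flatten_fn, unflatten_fn):
    """Register a container type.

    flatten_fn(obj) -> (metadata, children); unflatten_fn(metadata, children) -> obj.
    """
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


def tree_flatten(obj):
    """Return (leaves, treedef); unregistered objects are treated as leaves."""
    entry = _REGISTRY.get(type(obj))
    if entry is None:
        return [obj], None
    flatten_fn, _ = entry
    metadata, children = flatten_fn(obj)
    leaves, child_defs = [], []
    for child in children:
        child_leaves, child_def = tree_flatten(child)
        leaves.extend(child_leaves)
        # Remember how many leaves each child consumed, for unflattening.
        child_defs.append((len(child_leaves), child_def))
    return leaves, (type(obj), metadata, child_defs)


def tree_unflatten(treedef, leaves):
    """Rebuild the original structure from a treedef and a flat leaf list."""
    if treedef is None:
        (leaf,) = leaves
        return leaf
    cls, metadata, child_defs = treedef
    _, unflatten_fn = _REGISTRY[cls]
    children, pos = [], 0
    for count, child_def in child_defs:
        children.append(tree_unflatten(child_def, leaves[pos:pos + count]))
        pos += count
    return unflatten_fn(metadata, tuple(children))


@dataclass
class CompositeInput:
    value: Any  # dynamic part; would be a tensor in a real model
    meta: int   # static config


register_node(
    CompositeInput,
    lambda x: ((x.meta,), (x.value,)),
    lambda metadata, components: CompositeInput(value=components[0], meta=metadata[0]),
)


def check_leaves(inputs, is_tensor):
    """Hypothetical input check that validates leaves, not the container."""
    leaves, _ = tree_flatten(inputs)
    for leaf in leaves:
        if not is_tensor(leaf):
            raise ValueError(f"Unsupported input leaf: {type(leaf).__name__}")
    return inputs
```

For example, flattening `CompositeInput(value=[1.0, 2.0], meta=3)` yields the single leaf `[1.0, 2.0]`, with `meta=3` carried in the tree definition, and `tree_unflatten` rebuilds an equal object. In JAX the registration step corresponds to `jax.tree_util.register_pytree_node`.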

TF extension type example:

```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras_core

import tensorflow as tf
import keras_core


class CompositeTensor(tf.experimental.ExtensionType):
    value: tf.Tensor
    meta: int

    def __tf_flatten__(self):
        metadata = (self.meta,)  # static config.
        components = (self.value,)  # dynamic values.
        return metadata, components

    @classmethod
    def __tf_unflatten__(cls, metadata, components):
        # Rebuild by keyword so the static and dynamic parts land in the right fields.
        (meta,) = metadata
        (value,) = components
        return cls(value=value, meta=meta)


class ModelCheck(keras_core.Model):
    def __init__(self):
        super().__init__()
        self.layer = keras_core.layers.Dense(32)

    def call(self, inp, training=None):
        # Only the dynamic tensor is fed to the Dense layer.
        return self.layer(inp.value)


m = ModelCheck()

inp = CompositeTensor(value=tf.random.uniform((10, 64)), meta=3)
print([type(v) for v in tf.nest.flatten(inp)])
out = m(inp)  # fails here, at the value check mentioned above
```
