I rely heavily on passing structured inputs to the call methods of subclassed Models and Layers ({Model, Layer}.call). Will Keras 3 support this?
Currently I can't pass a TensorFlow ExtensionType or a generic dataclass (a PyTreeNode in JAX) to call: the input hits this value check.
I believe this kind of structured input should be supportable, especially given TensorFlow's __tf_flatten__ / __tf_unflatten__ protocol and JAX's pytree registration functionality.
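For comparison, this is roughly what the JAX side of the request looks like. It is a minimal sketch that assumes a recent JAX install. `Composite` is a hypothetical class mirroring the TF example below. It is registered with `jax.tree_util.register_pytree_node_class`, so `jax.tree_util` and `jax.jit` treat its tensor as a leaf and `meta` as static auxiliary data:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Composite:
    """Hypothetical structured input: one dynamic array plus static metadata."""

    def __init__(self, value, meta):
        self.value = value
        self.meta = meta

    def tree_flatten(self):
        children = (self.value,)  # dynamic leaves (traced under jit)
        aux_data = (self.meta,)   # static, hashable metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(value=children[0], meta=aux_data[0])


x = Composite(jnp.ones((10, 64)), meta=3)

# Round-trip through the pytree machinery.
leaves, treedef = jax.tree_util.tree_flatten(x)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# tree_map only touches the leaves; meta is carried through unchanged.
doubled = jax.tree_util.tree_map(lambda a: a * 2, x)


@jax.jit
def scaled_sum(c):
    # c.meta is a plain Python int here because it lives in aux_data.
    return c.value.sum() * c.meta


print(scaled_sum(x))
```

Note that JAX's `tree_flatten` returns `(children, aux_data)`, which is the reverse of the `(metadata, components)` order TensorFlow expects from `__tf_flatten__`.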
TF extension type example:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras_core

import tensorflow as tf
import keras_core


class CompositeTensor(tf.experimental.ExtensionType):
    value: tf.Tensor
    meta: int

    def __tf_flatten__(self):
        metadata = (self.meta,)  # static config.
        components = (self.value,)  # dynamic values.
        return metadata, components

    @classmethod
    def __tf_unflatten__(cls, metadata, components):
        # Rebuild by keyword: the positional field order is (value, meta),
        # so cls(*metadata, *components) would swap the two fields.
        return cls(value=components[0], meta=metadata[0])


class ModelCheck(keras_core.Model):
    def __init__(self):
        super().__init__()
        self.layer = keras_core.layers.Dense(32)

    def call(self, inp, training=None):
        # Unpack the structured input and apply the dense layer to its tensor.
        return self.layer(inp.value)


m = ModelCheck()
inp = CompositeTensor(value=tf.random.uniform((10, 64)), meta=3)
print([type(v) for v in tf.nest.flatten(inp)])
out = m(inp)  # fails here: the structured input hits the value check