diff --git a/tensorflow_graphics/projects/gan/architectures_style_gan_v2.py b/tensorflow_graphics/projects/gan/architectures_style_gan_v2.py index 4817d40a7..4eed2ff8e 100644 --- a/tensorflow_graphics/projects/gan/architectures_style_gan_v2.py +++ b/tensorflow_graphics/projects/gan/architectures_style_gan_v2.py @@ -242,6 +242,7 @@ def create_discriminator( kernel_initializer: Optional[_KerasInitializer] = None, use_fan_in_scaled_kernels: bool = True, use_antialiased_bilinear_downsampling: bool = False, + num_channels: int = 3, name: str = 'style_gan_v2_discriminator'): """Creates a Keras model for the discriminator architecture. @@ -262,6 +263,7 @@ def create_discriminator( ani-aliased bilinear downsampling with a [1, 3, 3, 1] tent kernel. If false standard bilinear downsampling, i.e. average pooling is used ([1, 1] tent kernel). + num_channels: The number of channels of the input tensor. name: The name of the Keras model. Returns: @@ -271,7 +273,7 @@ def create_discriminator( kernel_initializer = tf.keras.initializers.TruncatedNormal( mean=0.0, stddev=1.0) - input_tensor = tf.keras.Input(shape=(None, None, 3)) + input_tensor = tf.keras.Input(shape=(None, None, num_channels)) tensor = architectures_progressive_gan.from_rgb( input_tensor=input_tensor, use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,