diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index 418f2982777..804e9323a0f 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -45,7 +45,7 @@ class RandomGrayscale(BaseImagePreprocessingLayer): will have the same value. """ - def __init__(self, factor=0.5, data_format=None, **kwargs): + def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs): super().__init__(**kwargs) if factor < 0 or factor > 1: raise ValueError( @@ -54,7 +54,8 @@ def __init__(self, factor=0.5, data_format=None, **kwargs): ) self.factor = factor self.data_format = backend.standardize_data_format(data_format) - self.generator = self.backend.random.SeedGenerator() + self.seed = seed + self.generator = self.backend.random.SeedGenerator(seed) def get_random_transformation(self, images, training=True, seed=None): if seed is None: