Skip to content

Commit

Permalink
Fix torch gpu ci
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Dec 10, 2024
1 parent a0f586b commit 8968ea8
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions keras/src/layers/preprocessing/image_preprocessing/mix_up.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import keras.src.random.random
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
Expand Down Expand Up @@ -54,13 +53,16 @@ def get_random_transformation(self, data, training=True, seed=None):
else:
batch_size = self.backend.shape(images)[0]

if seed is None:
seed = self._get_seed_generator(self.backend._backend)

permutation_order = self.backend.random.shuffle(
self.backend.numpy.arange(0, batch_size, dtype="int64"),
seed=self.generator,
seed=seed,
)

mix_weight = keras.src.random.random.beta(
(1,), self.alpha, self.alpha, seed=self.generator
mix_weight = self.backend.random.beta(
(1,), self.alpha, self.alpha, seed=seed
)
return {
"mix_weight": mix_weight,
Expand Down

0 comments on commit 8968ea8

Please sign in to comment.