Skip to content

Commit

Permalink
Fix RandomZoom issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Oct 2, 2023
1 parent 43be5fc commit 5536dc5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
8 changes: 6 additions & 2 deletions keras/layers/preprocessing/random_zoom.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,12 @@ def _get_zoom_matrix(self, zooms, image_height, image_width):
# [0 0 1]]
# where the last entry is implicit.
# zoom matrices are always float32.
x_offset = ((float(image_width) - 1.0) / 2.0) * (1.0 - zooms[:, 0:1])
y_offset = ((float(image_height) - 1.0) / 2.0) * (1.0 - zooms[:, 1:])
x_offset = ((self.backend.cast(image_width, "float32") - 1.0) / 2.0) * (
1.0 - zooms[:, 0:1]
)
y_offset = (
(self.backend.cast(image_height, "float32") - 1.0) / 2.0
) * (1.0 - zooms[:, 1:])
return self.backend.numpy.concatenate(
[
zooms[:, 0:1],
Expand Down
12 changes: 12 additions & 0 deletions keras/layers/preprocessing/random_zoom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from keras import backend
from keras import layers
from keras import models
from keras import testing


Expand Down Expand Up @@ -108,3 +109,14 @@ def test_tf_data_compatibility(self):
for output in ds.take(1):
output = output.numpy()
self.assertAllClose(expected_output, output)

def test_dynamic_shape(self):
inputs = layers.Input((None, None, 3))
outputs = layers.RandomZoom(
height_factor=(0.5, 0.5),
width_factor=(0.8, 0.8),
interpolation="nearest",
fill_mode="constant",
)(inputs)
model = models.Model(inputs, outputs)
model.predict(np.random.random((1, 6, 6, 3)))

0 comments on commit 5536dc5

Please sign in to comment.