Skip to content

Commit

Permalink
Update image.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 8, 2023
1 parent f115eb7 commit 99361fa
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions serket/nn/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,30 +782,27 @@ def spatial_ndim(self) -> int:
return 2


def perspective_transform(coeffs, points):
"""Applies a perspective transform to a set of points."""
def perspective_transform(image: jax.Array, coeffs: jax.Array) -> jax.Array:
"""Apply a perspective transform to an image."""

_, rows, cols = image.shape
y, x = jnp.meshgrid(jnp.arange(rows), jnp.arange(cols), indexing="ij")
a, b, c, d, e, f, g, h = coeffs
x, y = points[..., 0], points[..., 1]
w = g * x + h * y + 1.0
x_prime = (a * x + b * y + c) / w
y_prime = (d * x + e * y + f) / w
return jnp.stack([x_prime, y_prime], axis=-1)


def transform_image(image, coeffs):
_, rows, cols = image.shape
y, x = jnp.meshgrid(jnp.arange(rows), jnp.arange(cols), indexing="ij")
primes = perspective_transform(coeffs, jnp.stack([x, y], axis=-1))
coordinates = [primes[..., 1].ravel(), primes[..., 0].ravel()]
coords = [y_prime.ravel(), x_prime.ravel()]

def transform_channel(image):
return map_coordinates(image, coordinates, order=1).reshape(rows, cols)
return map_coordinates(image, coords, order=1).reshape(rows, cols)

return jax.vmap(transform_channel)(image)


def random_perspective(
image: jax.Array, key: jax.random.KeyArray, scale: float = 1.0
image: jax.Array,
key: jax.random.KeyArray,
scale: float = 1.0,
) -> jax.Array:
"""Applies a random perspective transform to a channel-first image"""
_, __, ___ = image.shape
Expand All @@ -814,7 +811,7 @@ def random_perspective(
c = f = 0.0 # no translation
g, h = jr.uniform(key, shape=(2,), minval=-1e-4, maxval=1e-4) * scale
coeffs = jnp.array([a, b, c, d, e, f, g, h])
return transform_image(image, coeffs)
return perspective_transform(image, coeffs)


class RandomPerspective2D(sk.TreeClass):
Expand Down

0 comments on commit 99361fa

Please sign in to comment.