Skip to content

Commit

Permalink
remove jax static argnums in bilateral blur
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 10, 2024
1 parent bc2509c commit 80f2cc2
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions serket/_src/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,6 @@ def fft_elastic_transform_2d(
return map_coordinates(image, (ny, nx), order=1, mode="nearest").reshape(r, c)


@ft.partial(jax.jit, inline=True, static_argnums=1)
def bilateral_blur_2d(
image: HWArray,
kernel_size: tuple[int, int],
Expand Down Expand Up @@ -663,7 +662,6 @@ def bilateral_blur_2d(array):
return bilateral_blur_2d(image)


@ft.partial(jax.jit, inline=True, static_argnums=2)
def joint_bilateral_blur_2d(
image: HWArray,
guidance: HWArray,
Expand Down

0 comments on commit 80f2cc2

Please sign in to comment.