Gaussian blur cpu performance #46
Hi @mgoulao, thanks for reaching out! Indeed, I've also tested this, and it does not perform well on CPU. Transferring the image to CPU only helps a little: it saves a few µs on an operation that takes several ms. This is not intended as such; the goal of PIX is to have implementations that perform well on TPUs/GPUs, and to accept whatever performance that yields on CPUs. This doesn't mean, of course, that we don't want (or need) to improve the CPU implementation as well 😄 Feel free to submit a PR with any CPU optimisation!
Hi @mgoulao, I made a JAX package for stencil computation, `kernex`, that can be used to compute the Gaussian blur. I checked its performance against `dm_pix.gaussian_blur`. Hope this helps:

```python
# !pip install dm_pix
# !pip install kernex
import jax
import jax.numpy as jnp
import kernex as kex
import dm_pix
import numpy.testing as npt


def gaussian_blur(image, sigma, kernel_size):
    # Standard 1D Gaussian taps exp(-x^2 / (2 sigma^2)) on a centred grid.
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x) / jnp.square(sigma))
    w = jnp.outer(w, w)  # 2D kernel as the outer product of the 1D taps
    w = w / w.sum()      # normalize so the kernel sums to 1

    # kernex applies `conv` to every kernel_size x kernel_size window ("same" padding).
    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        return jnp.sum(x * w)

    return conv(image)


sigma = 1.0
kernel_size = 5

gaussian_blur_pix = jax.jit(lambda x: dm_pix.gaussian_blur(x, sigma, kernel_size))
gaussian_blur_kex = jax.jit(lambda x: gaussian_blur(x, sigma, kernel_size))

x = jax.random.uniform(jax.random.PRNGKey(0), (512, 512))
xx = jnp.expand_dims(x, axis=2)  # dm_pix expects an (H, W, C) image

npt.assert_allclose(gaussian_blur_pix(xx)[:, :, 0], gaussian_blur_kex(x), atol=1e-5)

# warm up (triggers compilation)
gaussian_blur_pix(xx)
gaussian_blur_kex(x)

%timeit gaussian_blur_pix(xx).block_until_ready()
%timeit gaussian_blur_kex(x).block_until_ready()
```

```
111 ms ± 40 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
11.1 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

On a Colab GPU, the timings are:

```
324 µs ± 111 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
200 µs ± 4.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
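The 2D kernel in the snippet above is the outer product `jnp.outer(w, w)` of a 1D Gaussian, so the same blur can also be computed as two 1D passes (separable filtering), which needs 2k instead of k² multiply-adds per pixel for a k×k kernel. This is a different technique from both the kernex and dm_pix versions discussed here. The sketch below assumes a single-channel image with zero padding, uses only `jax.numpy` and `jax.scipy.signal.convolve2d`, and the helper names (`gaussian_taps`, `blur_2d`, `blur_separable`) are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_taps(sigma, kernel_size):
    """Normalized 1D Gaussian weights exp(-x^2 / (2 sigma^2)) on a centred grid."""
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    return w / w.sum()


def blur_2d(image, sigma, kernel_size):
    """Reference: one 2D convolution with the outer-product kernel (zero padding)."""
    w = gaussian_taps(sigma, kernel_size)
    return convolve2d(image, jnp.outer(w, w), mode="same")


def blur_separable(image, sigma, kernel_size):
    """Same blur as two 1D passes: every row first, then every column."""
    w = gaussian_taps(sigma, kernel_size)
    conv1d = lambda v: jnp.convolve(v, w, mode="same")
    rows = jax.vmap(conv1d)(image)                        # filter along axis 1
    return jax.vmap(conv1d, in_axes=1, out_axes=1)(rows)  # filter along axis 0


image = jax.random.uniform(jax.random.PRNGKey(0), (128, 128))
full = blur_2d(image, 1.0, 5)
sep = blur_separable(image, 1.0, 5)
max_err = float(jnp.max(jnp.abs(full - sep)))  # should be at float32 round-off level
```

Whether the separable form is actually faster on a given backend has to be measured; it mainly pays off as `kernel_size` grows.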
Thanks for reporting this as well! I'm a bit short of time at the moment, and I'm afraid I will be for the whole of October. I'll try to have a look asap, or otherwise at the beginning of November. In the meantime, if you come up with a better implementation that also works well on CPU without extra dependencies, feel free to submit a PR! 🚀
Noted, thanks.
Hey, I implemented it; you can find tests and benchmarks against the depthwise-based implementation.

```
# average time ratio pix/kex for 3x3 kernel
# (64, 64, 1): 12.17
# (128, 128, 1): 14.70
# (256, 256, 1): 17.38
# (512, 512, 1): 16.37
# (64, 64, 32): 62.64
# (128, 128, 32): 44.88
# (256, 256, 32): 36.19
# (512, 512, 32): 36.60
# (64, 64, 64): 42.34
# (128, 128, 64): 80.46
# (256, 256, 64): 57.42
# (512, 512, 64): 54.94
```

For GPU, the speed-up ratio is:

```
# average time ratio pix/kex for 3x3 kernel
# (64, 64, 1): 1.76
# (128, 128, 1): 1.87
# (256, 256, 1): 1.82
# (512, 512, 1): 1.98
# (64, 64, 32): 1.81
# (128, 128, 32): 2.67
# (256, 256, 32): 2.72
# (512, 512, 32): 5.24
# (64, 64, 64): 2.96
# (128, 128, 64): 1.78
# (256, 256, 64): 3.22
# (512, 512, 64): 8.81
```

Let me know if it's suitable for a PR.

Best.
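The baseline in these benchmarks is the depthwise-based implementation used by PIX. To illustrate what a depthwise Gaussian blur looks like in JAX, here is a minimal sketch (not dm_pix's actual source): a single `jax.lax.conv_general_dilated` call with `feature_group_count` equal to the number of channels, so each channel is convolved only with its own copy of the kernel. It assumes an (H, W, C) float image and zero "SAME" padding; the helper names are hypothetical, and the last lines compare against a per-channel `convolve2d` as a sanity check:

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_kernel_2d(sigma, kernel_size):
    """Normalized 2D Gaussian kernel built as an outer product of 1D taps."""
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    w = w / w.sum()
    return jnp.outer(w, w)


def depthwise_gaussian_blur(image, sigma, kernel_size):
    """Blur an (H, W, C) image channel by channel with one grouped convolution."""
    channels = image.shape[-1]
    k2d = gaussian_kernel_2d(sigma, kernel_size)
    # HWIO layout: 1 input channel per group, C output channels (one per group).
    kernel = jnp.tile(k2d[:, :, None, None], (1, 1, 1, channels))
    out = jax.lax.conv_general_dilated(
        image[None],                    # add a batch axis: (1, H, W, C)
        kernel,
        window_strides=(1, 1),
        padding="SAME",                 # zero padding, output stays H x W
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=channels,   # each channel is filtered independently
    )
    return out[0]


image = jax.random.uniform(jax.random.PRNGKey(0), (64, 64, 3))
blurred = jax.jit(depthwise_gaussian_blur, static_argnums=2)(image, 1.0, 5)

# Sanity check: the same blur computed one channel at a time.
ref = jnp.stack(
    [convolve2d(image[..., c], gaussian_kernel_2d(1.0, 5), mode="same") for c in range(3)],
    axis=-1,
)
max_err = float(jnp.max(jnp.abs(blurred - ref)))
```

Because the Gaussian kernel is symmetric, it makes no difference that `lax` convolutions are cross-correlations rather than flipped convolutions.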
Thanks @ASEM000, I'll have a look at it as soon as I can; unfortunately, that will probably be at the end of the month 😭
I just skimmed through the code, without checking the implementation details.
Yes, you are right; sorry for the typo.
That's OK. From a quick skim it looks good, but let's resume this at the end of the month, so I have more time to look into the code and give proper advice on submitting a PR 😄
I'm finally back. I'll try to look into this asap!
Hello, additionally, I implemented a Gaussian filter based on FFT depthwise convolution, which should be faster for large kernels.
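The FFT-based filter itself is not reproduced in this thread. As a minimal sketch of the general idea (hypothetical helper names, not the commenter's code), assuming an (H, W, C) float image and zero padding: by the convolution theorem, multiplying the 2D FFTs of the zero-padded image and kernel gives the FFT of their linear convolution, so the cost is dominated by the transforms rather than by the k² taps per pixel. The last lines compare against a direct per-channel `convolve2d`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_kernel_2d(sigma, kernel_size):
    """Normalized 2D Gaussian kernel built as an outer product of 1D taps."""
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    w = w / w.sum()
    return jnp.outer(w, w)


def fft_gaussian_blur(image, sigma, kernel_size):
    """Zero-padded 'same' Gaussian blur of an (H, W, C) image via the FFT."""
    height, width, _ = image.shape
    k2d = gaussian_kernel_2d(sigma, kernel_size)
    # Pad both transforms to the full linear-convolution size so the
    # circular convolution computed by the FFT does not wrap around.
    fh, fw = height + kernel_size - 1, width + kernel_size - 1
    img_f = jnp.fft.rfft2(image, s=(fh, fw), axes=(0, 1))  # one 2D FFT per channel
    ker_f = jnp.fft.rfft2(k2d, s=(fh, fw))                 # shared by all channels
    full = jnp.fft.irfft2(img_f * ker_f[:, :, None], s=(fh, fw), axes=(0, 1))
    # Crop the centre: same alignment as convolve2d(..., mode="same").
    start = (kernel_size - 1) // 2
    return full[start:start + height, start:start + width, :]


image = jax.random.uniform(jax.random.PRNGKey(0), (64, 64, 3))
blurred = fft_gaussian_blur(image, 3.0, 21)

# Sanity check against a direct per-channel 2D convolution.
ref = jnp.stack(
    [convolve2d(image[..., c], gaussian_kernel_2d(3.0, 21), mode="same") for c in range(3)],
    axis=-1,
)
max_err = float(jnp.max(jnp.abs(blurred - ref)))
```

For small kernels such as 3×3 the direct or depthwise convolution is usually cheaper; the FFT route is meant for the large-kernel case mentioned above.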
Hey @ASEM000, I have not forgotten about this 😄
I have been doing some experiments with PIX, since it allows computing image augmentations on the GPU, in contrast to torchvision, which computes them on the CPU and requires multiple workers to avoid bottlenecks. When running some very simple `timeit` examples, I observed very long run times for the Gaussian blur on the CPU. I created a simple Colab notebook to demonstrate these experiments. I even tested transferring the image to the CPU before performing the blur, but it doesn't seem to make any difference. I was wondering whether this is intended, and I should not rely on CPU computations at all, or whether something is yet to be optimized for CPU computation.
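For reproducing such timings outside a notebook's `%timeit`, here is a minimal sketch assuming only `jax` (the helper name `time_on_device` is hypothetical, and the blur is a plain 5×5 Gaussian `convolve2d` rather than `dm_pix`): the input is committed to a device with `jax.device_put` so the jitted function runs there, one warm-up call keeps compilation out of the measurement, and `block_until_ready` is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def time_on_device(fn, x, device, n_iter=10):
    """Average wall-clock seconds per call of jit(fn)(x), with x committed to `device`."""
    x = jax.device_put(x, device)  # committed input: the jitted computation runs on `device`
    f = jax.jit(fn)
    f(x).block_until_ready()       # warm-up call, so compilation is not timed
    start = time.perf_counter()
    for _ in range(n_iter):
        f(x).block_until_ready()   # JAX dispatch is asynchronous; wait for the result
    return (time.perf_counter() - start) / n_iter


# 5x5 Gaussian kernel with sigma = 1, normalized to sum to 1.
taps = jnp.exp(-0.5 * jnp.arange(-2.0, 3.0) ** 2)
kernel = jnp.outer(taps, taps) / jnp.sum(taps) ** 2

cpu = jax.devices("cpu")[0]
image = jax.random.uniform(jax.random.PRNGKey(0), (512, 512))
seconds = time_on_device(lambda img: convolve2d(img, kernel, mode="same"), image, cpu)
print(f"{seconds * 1e3:.3f} ms per call on {cpu}")
```

Passing a GPU device from `jax.devices("gpu")` instead gives the corresponding GPU timing for the same function.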