JAX/FLAX scales up worse than TensorFlow #16473
giorgiofranceschelli asked this question in Q&A
Hello,
I started learning JAX/FLAX with a very simple VAE model trained on MNIST, and when I saw that it performed better than the TensorFlow version, I decided to move my current project to JAX/FLAX. However, with bigger and more complex architectures, I got worse performance than the original TF implementation. So I went back to the VAE on MNIST and checked whether it scales up correctly, but it does not seem to be the case.
In particular, with this implementation:
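A minimal sketch of this kind of convolutional FLAX VAE (the module names, layer counts, strides, and the `z_dim`/`conv_filters`/`image_dim` fields below are illustrative assumptions, not the exact original code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Encoder(nn.Module):
    z_dim: int
    conv_filters: int

    @nn.compact
    def __call__(self, x):
        # Two stride-2 convolutions downsample the image by a factor of 4.
        x = nn.relu(nn.Conv(self.conv_filters, (3, 3), strides=(2, 2))(x))
        x = nn.relu(nn.Conv(self.conv_filters * 2, (3, 3), strides=(2, 2))(x))
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(self.z_dim)(x), nn.Dense(self.z_dim)(x)  # mean, logvar

class Decoder(nn.Module):
    conv_filters: int
    image_dim: tuple  # (H, W, C)

    @nn.compact
    def __call__(self, z):
        h, w, c = self.image_dim
        x = nn.relu(nn.Dense(h // 4 * (w // 4) * self.conv_filters * 2)(z))
        x = x.reshape((z.shape[0], h // 4, w // 4, self.conv_filters * 2))
        x = nn.relu(nn.ConvTranspose(self.conv_filters, (3, 3), strides=(2, 2))(x))
        return nn.ConvTranspose(c, (3, 3), strides=(2, 2))(x)  # reconstruction logits

class VAE(nn.Module):
    z_dim: int
    conv_filters: int
    image_dim: tuple

    @nn.compact
    def __call__(self, x, rng):
        mean, logvar = Encoder(self.z_dim, self.conv_filters)(x)
        # Reparameterization trick: sample z = mean + std * eps.
        z = mean + jnp.exp(0.5 * logvar) * jax.random.normal(rng, mean.shape)
        return Decoder(self.conv_filters, self.image_dim)(z), mean, logvar

model = VAE(z_dim=64, conv_filters=32, image_dim=(28, 28, 1))
params = model.init(jax.random.PRNGKey(0),
                    jnp.ones((64, 28, 28, 1)), jax.random.PRNGKey(1))
```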
By varying the image dimensions, latent size, batch size, or number of convolutional filters, I obtained the following performance comparisons:

- `z_dim = 64`, `bs = 64`, `image_dim = (28, 28, 1)`, `conv_filters = 32` (baseline)
- `z_dim = 64`, `bs = 64`, `image_dim = (64, 64, 1)`, `conv_filters = 32`
- `z_dim = 256`, `bs = 64`, `image_dim = (28, 28, 1)`, `conv_filters = 32`
- `z_dim = 64`, `bs = 256`, `image_dim = (28, 28, 1)`, `conv_filters = 32`
- `z_dim = 64`, `bs = 64`, `image_dim = (28, 28, 1)`, `conv_filters = 128`
I ran everything on a GPU in Google Colab.
As you can see, JAX/FLAX is considerably faster for the baseline experiment, but becomes slower than TensorFlow whenever z_dim, batch size, image dimensions, or the number of convolutional filters increases.
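For reference, timings like these should exclude the first jitted call (which includes XLA compilation) and block on the result before reading the clock, since JAX dispatches work asynchronously. A minimal, self-contained sketch of that measurement pattern (the toy `train_step` below is a stand-in, not the actual VAE update function):

```python
import time
import jax
import jax.numpy as jnp

# Toy stand-in for the real jitted update; only the timing pattern matters.
@jax.jit
def train_step(params, batch):
    grads = jax.grad(lambda p: jnp.mean((p * batch) ** 2))(params)
    return params - 1e-3 * grads

params = jnp.ones((256, 256))
batch = jnp.ones((256, 256))

# Warm-up call: the first invocation triggers XLA compilation,
# which should not count towards steady-state step time.
train_step(params, batch).block_until_ready()

start = time.perf_counter()
for _ in range(100):
    params = train_step(params, batch)
params.block_until_ready()  # wait for async dispatch before stopping the clock
print(f"{(time.perf_counter() - start) / 100 * 1000:.3f} ms/step")
```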
What am I doing wrong? Any help would be appreciated.