Replies: 1 comment
-
Cross-referencing your duplicate question in google/flax#2964, where I suspect you'll get a better answer.
-
Hi,
I am running the same ResNet50 model with the same weights in PyTorch and Flax, but the Flax version performs noticeably worse. When running inference with the PyTorch model, GPU utilization is around 100%, but with the Flax model it is only about 35%. I checked the arrays, and they are stored on the GPU.
I have warmed up the jitted function, and I believe I have placed both the model and the data on the device in both cases. What am I doing wrong? Is there a more optimized way to use jitted functions when building Flax models?
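For reference, here is a minimal sketch of the warm-up and timing pattern I understand to be standard with `jax.jit`. The `forward` function and its parameters are hypothetical stand-ins for the model's apply function, not my actual ResNet50 code. One thing worth noting: JAX dispatches asynchronously, so timing without `block_until_ready()` can under-report the real per-call latency.

```python
import time
import jax
import jax.numpy as jnp

# Hypothetical stand-in for model.apply: a small jitted forward pass.
@jax.jit
def forward(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])

key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (256, 256)), "b": jnp.zeros(256)}
x = jnp.ones((32, 256))

# First call triggers tracing and XLA compilation (slow); this is the warm-up.
forward(params, x).block_until_ready()

# Subsequent calls with the same shapes/dtypes reuse the compiled executable.
t0 = time.perf_counter()
out = forward(params, x).block_until_ready()  # wait for the device to finish
elapsed = time.perf_counter() - t0
print(out.shape)
```

Without the `block_until_ready()` call, `forward` returns a future-like array before the GPU has finished, so the measured time mostly reflects dispatch overhead rather than the computation itself.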
Notebook: https://colab.research.google.com/drive/1b486yGovsLLuGawwhFmp6r6YUNw6Eupo?usp=sharing
Environment info
pip install flax
pip install --upgrade jax==0.4.4 jaxlib==0.4.4+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html