The memory consumption is almost same #15231
Unanswered
DaShenZi721 asked this question in Q&A
Replies: 0 comments
I have recently been testing the memory usage of several baselines for text tasks on the Long-Range Arena benchmark. However, regardless of which baseline or batch size I use, the reported memory usage is almost the same (around 160 MB), as shown in the figure below.
The code can be found here:
https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/text_classification/train.py
I am using jax-smi, which starts an additional thread that runs the jax.profiler.save_device_memory_profile() command every second. I added the following two lines of code in the main function.
My JAX environment is as follows
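The polling described above (a background thread that calls jax.profiler.save_device_memory_profile() once per second) can be sketched roughly as follows. This is only a minimal illustration of the idea, not the actual jax-smi implementation and not the two lines mentioned in the post; the helper names start_memory_poller and poll_jax_device_memory and the file name memory.prof are hypothetical.

```python
import threading


def start_memory_poller(save_profile, path, interval_s=1.0):
    """Call save_profile(path) every interval_s seconds on a daemon thread.

    Returns (stop_event, thread); set the event to end the polling loop.
    """
    stop_event = threading.Event()

    def _loop():
        while not stop_event.is_set():
            save_profile(path)           # write a fresh snapshot to `path`
            stop_event.wait(interval_s)  # sleep, but wake early if stopped

    thread = threading.Thread(target=_loop, daemon=True)
    thread.start()
    return stop_event, thread


def poll_jax_device_memory(path="memory.prof", interval_s=1.0):
    """Hypothetical wrapper: poll JAX's device memory profile into `path`."""
    # Imported lazily so the generic helper above does not require JAX.
    import jax

    return start_memory_poller(
        jax.profiler.save_device_memory_profile, path, interval_s
    )
```

Under these assumptions, calling poll_jax_device_memory() once at the start of the main function would keep refreshing memory.prof while training runs, and setting the returned event would stop the polling.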