Use with static (unrolled) RNN? #13
Comments
We have not tested with any RNN examples. I'm somewhat reluctant to extend this to work on RNNs, given that there's a parallel effort to integrate memory saving into TensorFlow using the Grappler framework; @allenlavoie may know more details.
I'd note that the plan is for Grappler's memory optimizer to (eventually) do checkpointing; currently it only recomputes things once. Once checkpointing is implemented, I'm happy to look at static RNNs if they're an issue. I'm not working on it at the moment, and it will likely be at least a quarter or two before I can get back to it. Happy to chat/make connections if someone reading this is interested in picking it up.
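For anyone comparing the two approaches mentioned above, here is a back-of-envelope cost model (my own assumption, not Grappler's actual cost model) of why full checkpointing beats recompute-once for long graphs: keeping a checkpoint every `k` of `n` steps stores about `n/k + k` states at peak, minimized near `k = sqrt(n)`, at the price of roughly one extra forward pass.

```python
def peak_states(n, k):
    # Checkpoints kept (n // k) plus one live segment of length k.
    return n // k + k

def extra_forward_steps(n, k):
    # Each non-checkpoint state is recomputed once on the way back.
    return n - n // k

n = 100
best_k = min(range(1, n + 1), key=lambda k: n / k + k)
# best_k == 10, i.e. sqrt(100): ~20 stored states instead of 100,
# for ~90 extra forward steps of recomputation.
```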
Thanks for the info. We'll take a closer look at Grappler.
@yaroslavvb I have the same error with an RNN.
Hi guys, thanks for your contribution. I wanted to give some feedback and request that you add a static (unrolled) RNN to your test suite. If/when I get a chance to spend more time on this, I'm happy to contribute this myself.
I tried using your code with a 2-layer LSTM RNN using dynamic_rnn and hit the same issue as here: #9
I converted my model to use static_rnn, which removes the while loop by statically unrolling for a fixed sequence length. At that point, your code was unable to automatically find articulation points, so I tried adding manual checkpoints in a few intuitive places: at the output of each layer, at every unrolled loop iteration, and at every k unrolled loop iterations. In all cases, memory usage was still higher than the baseline. I inspected the modified backprop graph; it seemed to be doing a lot of redundant computation and not behaving as described in your write-up. I suspect I wasn't checkpointing correctly, so a working static-RNN test case would be a helpful reference.
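To make concrete what checkpointing every k unrolled steps is supposed to do, here is a toy, pure-Python sketch (not the library's actual API; the scalar tanh cell and the segment-recompute scheme are illustrative assumptions): the forward pass keeps only every k-th hidden state, and the backward pass recomputes each segment from its nearest checkpoint before chaining local Jacobians.

```python
import math

def rnn_step(h, x):
    # Toy scalar recurrence standing in for an LSTM cell: h' = tanh(h + x).
    return math.tanh(h + x)

def rnn_step_grad(h, x):
    # dh'/dh for the toy recurrence: 1 - tanh(h + x)^2.
    return 1.0 - math.tanh(h + x) ** 2

def backprop_full(h0, xs):
    # Baseline: store every hidden state, then chain local Jacobians.
    states = [h0]
    for x in xs:
        states.append(rnn_step(states[-1], x))
    grad = 1.0
    for t in range(len(xs) - 1, -1, -1):
        grad *= rnn_step_grad(states[t], xs[t])
    return grad

def backprop_checkpointed(h0, xs, k):
    # Forward pass: keep only every k-th hidden state (the checkpoints).
    checkpoints = {0: h0}
    h = h0
    for t, x in enumerate(xs, start=1):
        h = rnn_step(h, x)
        if t % k == 0:
            checkpoints[t] = h
    # Backward pass: walk segments in reverse; recompute each segment's
    # states from its left checkpoint, then chain the local Jacobians.
    grad = 1.0
    seg_end = len(xs)
    while seg_end > 0:
        seg_start = (seg_end - 1) // k * k
        states = [checkpoints[seg_start]]
        for t in range(seg_start, seg_end):
            states.append(rnn_step(states[-1], xs[t]))
        for t in range(seg_end - 1, seg_start - 1, -1):
            grad *= rnn_step_grad(states[t - seg_start], xs[t])
        seg_end = seg_start
    return grad
```

If manual checkpoints are placed this way, the checkpointed gradient should match the store-everything baseline exactly while keeping only ~T/k states plus one live segment; if memory goes up instead (as I observed), the rewritten graph is probably retaining both the original and the recomputed activations.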