When using the JAX backend (JAX version 0.4.28), I'm encountering an "Array has been deleted" error during training when a convolutional layer is placed within a RematScope. The deleted array is the layer's convolutional kernel (its weights).
This issue seems to be specific to convolutional layers, as using a dense layer within the RematScope works without errors. The training function is already jit-compiled.
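For readers unfamiliar with this error, here is a minimal sketch of how JAX reports access to an array whose buffer has been freed. It does not reproduce the RematScope failure and makes no claim about its cause. The `kernel` name and its shape are hypothetical stand-ins for a convolution kernel, and the explicit `delete()` call is only an assumption used to simulate an invalidated buffer.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical stand-in for a convolution kernel; the shape is arbitrary.
kernel = jnp.ones((3, 3, 1, 8))

# Explicitly free the underlying buffer to simulate an invalidated array.
# (Assumption: in the reported failure the kernel is invalidated some other
# way; this call only makes the error easy to trigger.)
kernel.delete()
deleted = kernel.is_deleted()

# Reading a deleted array raises a RuntimeError.
try:
    np.asarray(kernel)
    message = ""
except RuntimeError as err:
    message = str(err)

print("is_deleted:", deleted)
print("error:", message)
```

When debugging a report like this one, checking `is_deleted()` on the layer's variables just before the failing step can help narrow down where the kernel was invalidated.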
JAX version: 0.4.28
Link to failed invocation: https://btx.cloud.google.com/invocations/d2e33ea2-b22c-44ad-8eb8-ca105455c926/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Fjax%2Fpresubmit/log