-
-
Notifications
You must be signed in to change notification settings - Fork 157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Tests and Cleanup #873
Conversation
I think if we are removing |
I think JuliaDiff/ChainRules.jl#713 is related. |
fe5f667
to
8e4cd87
Compare
2aa4dd2
to
3c4fcb3
Compare
190aa0a
to
e521e2c
Compare
f2bd8cd
to
c7a3158
Compare
c7a3158
to
28f1644
Compare
Fixes #781 |
Fixing the AD problem closes #814 |
Supporting only Lux makes #795 irrelevant and closes it. |
Now that we are breaking some APIs, should we consider #790 ? |
ba86fb6
to
a072fd8
Compare
a072fd8
to
3a1ffba
Compare
16034dd
to
531132a
Compare
683ac1b
to
ec126f4
Compare
ec126f4
to
13e4b83
Compare
aa6fe16
to
299ca49
Compare
What's this part? |
mz, pb_f = Zygote.pullback(model, z, p) | ||
e = CRC.@ignore_derivatives randn!(similar(mz)) | ||
eJ = first(pb_f(e)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should really change this to forward mode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added ADTypes
here so we can switch the internal AD using ad = AutoForwardDiff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using JacVec
or VecJac
from SparseDiffTools
, we can pass ad
to it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah what I mean is that it should be using JacVec (and almost never VecJac). It's e'Je, and so it should essentially always calculate e' * (Je) and not (e'J) * e which is what it does now (along with the PyTorch code)
The DAEs didnot have any tests running even before, and I couldn't get AD to work currently |
Which form? The mass matrix? |
MM form (NeuralODEMM) works. the NeuralDAE form |
Fixes #781 Fixes #795 Fixes #639 Fixes #638, Fixes #707 Fixes #733