Implement forward- and reverse mode AD in the interpreter #2186
It is probably worth mentioning that the old version of the code can be found here: interpreter-ad-old
Will you fix the remaining style errors or shall I?
Honestly, I'd love to, but I'm not entirely sure that I can, within a reasonable time frame. I'm not yet comfortable enough to feel that I understand the "Haskell" way of doing things, nor even the functional way - my ugly implementation of |
You literally just have to run the |
Wow, doesn't get any easier than that ;) I'll give it a try right away.
Oh, my bad, my tired eyes missed the |
This reverts commit 2e21d68.
Thank you for the work. I have merged your implementation and created this issue to address the most significant remaining problem: #2187. I would certainly welcome further contributions, but the current implementation is operational.
Apologies for closing the old PR; I am quite new to this.
So, as promised, I cleaned up my code a bit. That being said, more work needs doing.
Here is a list of tasks that immediately spring to mind:

- `vjp2` and `jvp2`. They are ugly, and hugely inefficient. It seems implementing them is my kryptonite. I look forward to seeing them have the beauty they deserve ;)
- `deriveTape`. I'm thinking this can be achieved by either (1) implementing `Tape` as a graph instead of a tree, or (2) assigning each `TapeOp` a unique ID using a counter in `EvalM`. `deriveTape` would have to initially run through the `Tape`, putting each unique `TapeOp` in a lookup table, and counting the references to it. The `Tape` can then be derived starting from the output. Each time a reference to a `TapeOp` is encountered, the sensitivity, which is propagated to it, is added into a pool kept in the lookup table, and the number of references is decreased by one. When the number of references reaches zero, the `Tape` is derived. Thus, each `Tape` is derived only once.
- `Interpreter.hs` to `AD.hs`. I feel like the former uses a lot of functions from the latter, making the code unnecessarily complex to read.
- `doOp` for computations of `ValuePrim`s. This would make the code for applying mathematical operations cleaner. Currently, it contains a lot of similar or duplicate code.

I have also littered the code with TODOs just ripe for the taking, and added a lot of explanatory text, as you mentioned that you would use this in your teaching. I have probably added too much, so feel free to delete it.