-
-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RNNs redesign #2500
base: master
Are you sure you want to change the base?
RNNs redesign #2500
Conversation
8abc593
to
aeb421b
Compare
Fully agree with updating the design to be non-mutating. There are two options we've discussed in the past:
Option 1 is outlined in this PR so I won't say anything about it. Option 2 is a more drastic redesign to make all layers (not just recurrent) non-mutating. Why?
|
I thought about Option 2. On the upside, it seems a nice intermediate spot between current Flux and Lux. The downside is that the interface would seem a bit exotic to flux and pytorch users. Moreover, it would be problematic for normalization layers. Also, we need to distinguish between normalization layers and recurrent layers.
|
6f35f2d
to
834bed3
Compare
The main benefit for keeping the state "internal" or having it be part of a unified interface like |
cf56985
to
73dae52
Compare
aa0655d
to
76cf275
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2500 +/- ##
==========================================
+ Coverage 33.46% 34.93% +1.46%
==========================================
Files 31 31
Lines 1829 1878 +49
==========================================
+ Hits 612 656 +44
- Misses 1217 1222 +5 ☔ View full report in Codecov by Sentry. |
I think this is ready. |
A complete rework of our recurrent layers, making them more similar to their pytorch counterpart.
This is in line with the proposal in #1365 and should allow to hook into the cuDNN machinery (future PR).
Hopefully, this ends the infinite source of troubles that the recurrent layers have been.
Recur
is no more. Mutating its internal state was a source of problems for AD (explicit differentiation for RNN gives wrong results #2185)RNNCell
is exported and takes care of the minimal recursion step, i.e. a single time:cell(x , h)
x
can be of sizein
orin x batch_size
h
can be of sizeout
orout x batch_size
hnew
of sizeout
orout x batch_size
RNN
instead takes in a (batched) sequence and a (batched) hidden state and returns the hidden state for the whole sequence:rnn(x, h)
x
can be of sizein x len
orin x len x batch_size
h
can be of sizeout
orout x batch_size
hnew
of sizeout x len
orout x len x batch_size
LSTM
andGRU
are similarly changed.Close #2185, close #2341, close #2258, close #1547, close #807, close #1329
Related to #1678
PR Checklist
LSTM
andGRU
reset!
cuDNN
(future PR)num_layers
argument for stacked RNNs (future PR)