Skip to content

Commit

Permalink
Make remat optional in scan as it's not in all versions of jax, add tests.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 290262833
  • Loading branch information
Lukasz Kaiser authored and copybara-github committed Jan 17, 2020
1 parent 4bfc87c commit 899cba8
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 5 deletions.
2 changes: 2 additions & 0 deletions trax/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ def new_rngs(self, n):
"""
if n < 1:
raise ValueError('n must be > 0; received value: {}'.format(n))
if self._rng is None:
self._rng = math.random.get_prng(0)
rngs = math.random.split(self._rng, n + 1)
self._rng = rngs[0]
return tuple(rngs[1:])
Expand Down
5 changes: 3 additions & 2 deletions trax/layers/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,12 @@ def add(x)
Scan(add)([1, 2, 3], 0) = [1, 3, 6], 6
"""

def __init__(self, layer, axis=0, n_carry=1):
def __init__(self, layer, axis=0, n_carry=1, remat=False):
super(Scan, self).__init__(n_in=layer.n_in, n_out=layer.n_out)
self._sublayers = [layer]
self._n_carry = n_carry
self._axis = axis
self._remat = remat

@property
def sublayer(self):
Expand All @@ -395,7 +396,7 @@ def scannable_fn(x, carry_and_state): # pylint: disable=invalid-name
else:
xs, init = inputs, ([], state)
ys, (carry, new_state) = math.scan(scannable_fn, xs, init,
axis=self._axis)
axis=self._axis, remat=self._remat)
res = ys + carry if n_carry > 0 else ys
return res, new_state # Put outputs and carry back on stack.

Expand Down
8 changes: 6 additions & 2 deletions trax/math/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def jax_avg_pool(x, pool_size, strides, padding):
pool_size, strides=strides, padding=padding)


def _jax_scan(f, xs, init_value, axis=0):
def _jax_scan(f, xs, init_value, axis=0, remat=False):
"""Scans the f over the given axis of xs.
In pseudo-python, the scan function would look as follows:
Expand All @@ -111,6 +111,7 @@ def scan(f, xs, init_value, axis):
xs: tensor, x will be xs slices on axis
init_value: tensor, initial value of the carry-over
axis: int, the axis on which to slice xs
remat: whether to re-materialize f
Returns:
A pair (ys, last_value) as described above.
Expand All @@ -125,7 +126,10 @@ def swapaxes(x):
def transposed_f(c, x):
y, d = f(x, c)
return d, y
last_value, ys = lax.scan(jax.remat(transposed_f), init_value, xs)
if remat:
last_value, ys = lax.scan(jax.remat(transposed_f), init_value, xs)
else:
last_value, ys = lax.scan(transposed_f, init_value, xs)
if axis != 0:
ys = nested_map(swapaxes, ys)
return ys, last_value
Expand Down
2 changes: 1 addition & 1 deletion trax/models/reformer/reformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def reshape_to_chunks(x):
return [
tl.Dup(), # Just to have shape for later after scan.
tl.Fn(reshape_to_chunks, n_out=1),
tl.Scan(tl.Serial(ff), axis=0, n_carry=0),
tl.Scan(tl.Serial(ff), axis=0, n_carry=0, remat=True),
tl.Fn(lambda x, y: np.reshape(x, y.shape))
]

Expand Down
27 changes: 27 additions & 0 deletions trax/models/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,30 @@ def MultiRNNCell():
tl.Dense(vocab_size),
tl.LogSoftmax()
)


def GRULM(vocab_size=256,
          d_model=512,
          n_layers=1,
          mode='train'):
  """Builds a GRU language model.

  The input to the model is a tensor of tokens (ints).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding (n_units in the RNN cell)
    n_layers: int: number of RNN layers
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

  Returns:
    A GRU language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  # Sub-layers are created in pipeline order; `mode` is only forwarded to the
  # initial shift layer.
  shift_right = tl.ShiftRight(mode=mode)
  embedding = tl.Embedding(d_model, vocab_size)
  # One GRU layer of width d_model per requested RNN layer; handed to Serial
  # as a single nested list, exactly like the inline comprehension would be.
  gru_stack = [tl.GRU(d_model) for _ in range(n_layers)]
  to_vocab = tl.Dense(vocab_size)
  log_probs = tl.LogSoftmax()
  return tl.Serial(shift_right, embedding, gru_stack, to_vocab, log_probs)
8 changes: 8 additions & 0 deletions trax/models/rnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ def test_rnnlm_forward_shape(self):
final_shape = tl.check_shape_agreement(model, input_signature)
self.assertEqual((3, 28, 20), final_shape)

def test_grulm_forward_shape(self):
  """Checks that GRULM maps a (3, 28) int32 input to a (3, 28, 20) output."""
  # The trailing output dimension is expected to equal the vocab_size passed
  # to the model below.
  expected_shape = (3, 28, 20)
  token_signature = ShapeDtype((3, 28), dtype=math.numpy.int32)
  gru_lm = rnn.GRULM(vocab_size=20, d_model=16)
  gru_lm.init(token_signature)
  actual_shape = tl.check_shape_agreement(gru_lm, token_signature)
  self.assertEqual(expected_shape, actual_shape)


if __name__ == '__main__':
absltest.main()

0 comments on commit 899cba8

Please sign in to comment.