Skip to content

Commit

Permalink
Make apply_fn control flow jittable
Browse files Browse the repository at this point in the history
  • Loading branch information
EhsanEI committed Jul 18, 2024
1 parent 3c65486 commit 0f111c6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 16 deletions.
16 changes: 4 additions & 12 deletions jax_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,18 +282,10 @@ def project(self, x, train=True):
x = layer(x) if not isinstance(layer, nn.Dropout) else layer(x, deterministic=not train)
return x

def __call__(self, src=None, tgt=None, src_mask=None, tgt_mask=None,
train=True, encoder_output=None, encode=True, decode=True):
assert encode or decode
if not decode:
out = self.encode(src, src_mask, train=train)
elif not encode:
decoder_output = self.decode(encoder_output, tgt, src_mask, tgt_mask, train=train)
out = self.project(decoder_output, train=train)
else:
encoder_output = self.encode(src, src_mask, train=train)
decoder_output = self.decode(encoder_output, tgt, src_mask, tgt_mask, train=train)
out = self.project(decoder_output, train=train)
def __call__(self, src=None, tgt=None, src_mask=None, tgt_mask=None, train=True):
encoder_output = self.encode(src, src_mask, train=train)
decoder_output = self.decode(encoder_output, tgt, src_mask, tgt_mask, train=train)
out = self.project(decoder_output, train=train)
return out


Expand Down
9 changes: 5 additions & 4 deletions translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
main_rng = random.PRNGKey(config['seed'])


def greedy_decode(state, rng, src_sentence, initial_input=""):
def greedy_decode(state, transformer, rng, src_sentence, initial_input=""):
src_item = pad_to_seq_len(text_transform['src'](src_sentence))
src = jnp.array([src_item])

Expand All @@ -28,10 +28,11 @@ def greedy_decode(state, rng, src_sentence, initial_input=""):
src_mask, tgt_mask = generate_mask(src, tgt_input)

rng, dropout_apply_rng = random.split(rng)

encoder_output = state.apply_fn(
{'params': state.params},
src=src, src_mask=src_mask,
train=False, decode=False,
train=False, method=transformer.encode,
rngs={'dropout': dropout_apply_rng})

translation = []
Expand All @@ -42,7 +43,7 @@ def greedy_decode(state, rng, src_sentence, initial_input=""):
tgt_output = state.apply_fn(
{'params': state.params},
encoder_output=encoder_output, tgt=tgt_input, src_mask=src_mask, tgt_mask=tgt_mask,
train=False, encode=False,
train=False, method=transformer.decode,
rngs={'dropout': dropout_apply_rng})
tgt_output = tgt_output.argmax(axis=-1)

Expand All @@ -64,4 +65,4 @@ def greedy_decode(state, rng, src_sentence, initial_input=""):
src_sentence = input("Enter German text to translate or leave blank to exit:\n")
if not src_sentence:
break
print('Translation:', greedy_decode(state, rng, src_sentence), '\n')
print('Translation:', greedy_decode(state, transformer, rng, src_sentence), '\n')

0 comments on commit 0f111c6

Please sign in to comment.