Skip to content

Commit

Permalink
[TRAX] Fix the jnp.pad issue with mode.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 369691659
  • Loading branch information
afrozenator authored and copybara-github committed Apr 21, 2021
1 parent 65378ce commit df96ed6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion trax/layers/research/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def forward(self, x):
else:
pad_widths = [[0, 0] for _ in range(len(x.shape))]
pad_widths[1][0] = self._n_items_to_remember
x = jnp.pad(x, pad_width=pad_widths)
x = jnp.pad(x, pad_width=pad_widths, mode='constant')
return x

def init_weights_and_state(self, input_signature):
Expand Down

0 comments on commit df96ed6

Please sign in to comment.