From df96ed6701a6abaf570287641100347ed3c195e0 Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Wed, 21 Apr 2021 11:02:16 -0700 Subject: [PATCH] [TRAX] Fix the `jnp.pad` issue with mode. PiperOrigin-RevId: 369691659 --- trax/layers/research/sparsity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trax/layers/research/sparsity.py b/trax/layers/research/sparsity.py index 5bf322654..9c719a601 100644 --- a/trax/layers/research/sparsity.py +++ b/trax/layers/research/sparsity.py @@ -243,7 +243,7 @@ def forward(self, x): else: pad_widths = [[0, 0] for _ in range(len(x.shape))] pad_widths[1][0] = self._n_items_to_remember - x = jnp.pad(x, pad_width=pad_widths) + x = jnp.pad(x, pad_width=pad_widths, mode='constant') return x def init_weights_and_state(self, input_signature):