diff --git a/ivy/functional/frontends/torch/nn/functional/layer_functions.py b/ivy/functional/frontends/torch/nn/functional/layer_functions.py index 30bf2983e36d7..f40cc2e7959cc 100644 --- a/ivy/functional/frontends/torch/nn/functional/layer_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/layer_functions.py @@ -107,6 +107,9 @@ def _generic_lstm( h_outs = h_out if num_layers == 1 else ivy.concat(h_outs, axis=0) c_outs = c_out if num_layers == 1 else ivy.concat(c_outs, axis=0) + if batch_sizes is not None: + output = _pack_padded_sequence(output, batch_sizes)[0] + return output, h_outs, c_outs @@ -160,10 +163,9 @@ def _lstm_cell( output = ivy.concat(ht_list, axis=0) else: ct_list = ivy.concat(ct_list, axis=0) - ht_list = ivy.concat(ht_list, axis=0) + output = ht_list = ivy.concat(ht_list, axis=0) c = _extract_states(ct_list, batch_sizes) h = _extract_states(ht_list, batch_sizes) - output = _pack_padded_sequence(ht_list, batch_sizes)[0] return output, (h, c)