Commit

fix(torch-frontend): Fixed torch.lstm to get correct attention outputs in the packed case.
AnnaTz committed Oct 18, 2023
1 parent 92adae0 commit 95a7204
Showing 1 changed file with 4 additions and 2 deletions.
@@ -107,6 +107,9 @@ def _generic_lstm(
     h_outs = h_out if num_layers == 1 else ivy.concat(h_outs, axis=0)
     c_outs = c_out if num_layers == 1 else ivy.concat(c_outs, axis=0)
 
+    if batch_sizes is not None:
+        output = _pack_padded_sequence(output, batch_sizes)[0]
+
     return output, h_outs, c_outs


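The added branch re-packs the padded output before it is returned whenever the call originated from a packed sequence (batch_sizes is not None). As a rough, self-contained sketch of the data layout this relies on, where pack_by_batch_sizes is a hypothetical stand-in and not the frontend's real _pack_padded_sequence: each time step t keeps only its first batch_sizes[t] sequences, and the surviving rows are concatenated into one flat data tensor.

import numpy as np

def pack_by_batch_sizes(padded, batch_sizes):
    # Hypothetical illustration only: keep the first batch_sizes[t] rows of
    # each time step and flatten them into a (sum(batch_sizes), hidden) data
    # tensor, which is the layout a packed sequence stores.
    return np.concatenate(
        [padded[t, : batch_sizes[t]] for t in range(len(batch_sizes))], axis=0
    )

# Padded output of shape (max_seq_len=3, batch=2, hidden=1); the second
# sequence is only one step long, so its later steps are padding zeros.
padded = np.array([[[1.0], [4.0]],
                   [[2.0], [0.0]],
                   [[3.0], [0.0]]])
batch_sizes = [2, 1, 1]
print(pack_by_batch_sizes(padded, batch_sizes).shape)  # (4, 1): 2 + 1 + 1 rows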
@@ -160,10 +163,9 @@ def _lstm_cell(
         output = ivy.concat(ht_list, axis=0)
     else:
         ct_list = ivy.concat(ct_list, axis=0)
-        ht_list = ivy.concat(ht_list, axis=0)
+        output = ht_list = ivy.concat(ht_list, axis=0)
         c = _extract_states(ct_list, batch_sizes)
         h = _extract_states(ht_list, batch_sizes)
-        output = _pack_padded_sequence(ht_list, batch_sizes)[0]
     return output, (h, c)


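Taken together, the two hunks appear to move the packing step out of the per-layer cell and into _generic_lstm, so the padded output is packed exactly once after the last layer instead of inside _lstm_cell. A rough, self-contained sketch of that ordering, where toy_layer and pack_by_batch_sizes are purely hypothetical stand-ins for the real recurrent step and packing helper:

import numpy as np

def pack_by_batch_sizes(padded, batch_sizes):
    # Same hypothetical packing helper as in the sketch above.
    return np.concatenate(
        [padded[t, : batch_sizes[t]] for t in range(len(batch_sizes))], axis=0
    )

def toy_layer(x):
    # Hypothetical stand-in for one recurrent layer: it keeps the padded
    # (seq_len, batch, hidden) layout so the next layer can consume it.
    return np.tanh(x)

def generic_lstm_sketch(x, num_layers, batch_sizes=None):
    output = x
    for _ in range(num_layers):
        # Every layer sees, and returns, the padded layout.
        output = toy_layer(output)
    if batch_sizes is not None:
        # Pack once, at the very end, mirroring the branch added above.
        output = pack_by_batch_sizes(output, batch_sizes)
    return output

x = np.zeros((3, 2, 5))  # (max_seq_len, batch, hidden)
print(generic_lstm_sketch(x, num_layers=2, batch_sizes=[2, 1, 1]).shape)  # (4, 5)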
