Skip to content

Commit

Permalink
Fixed failing GPU test.
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelalonsojr committed Oct 4, 2024
1 parent 9d34978 commit ee94d91
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_predict_minimum_training():
argmin = argmin.squeeze()
argmin = argmin.detach()
sliced_oh = onehots[:, : num + 1]
inp = torch.cat([inp, sliced_oh], dim=2)
inp = torch.cat([inp, sliced_oh.to(torch.get_default_device())], dim=2)

embeddings = entity_embedding(inp, inp)
masks = get_zero_entities_mask([inp])
Expand Down

0 comments on commit ee94d91

Please sign in to comment.