diff --git a/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py b/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py index f7344a647b8..3a881ed5e7a 100644 --- a/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py +++ b/ml-agents/mlagents/trainers/tests/torch_entities/test_attention.py @@ -217,7 +217,7 @@ def test_predict_minimum_training(): argmin = argmin.squeeze() argmin = argmin.detach() sliced_oh = onehots[:, : num + 1] - inp = torch.cat([inp, sliced_oh], dim=2) + inp = torch.cat([inp, sliced_oh.to(torch.get_default_device())], dim=2) embeddings = entity_embedding(inp, inp) masks = get_zero_entities_mask([inp])