diff --git a/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py b/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py index fb8efa1894ae0..9f3435ce74582 100644 --- a/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py +++ b/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py @@ -45,7 +45,7 @@ def test(self, x_all, subgraph_loader): for batch_size, n_id, adj in subgraph_loader: edge_index, _, size = adj - x = x_all[n_id] + x = x_all[n_id.to(x_all.device)].to(train.torch.get_device()) x_target = x[: size[1]] x = self.convs[i]((x, x_target), edge_index) if i != self.num_layers - 1: