From d23018445bf52d9caa34e9ea69e843d23c1fbdd1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 13:25:58 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/ppo/utils_atari.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index ab39c102106..fa9d4bb053e 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -115,7 +115,7 @@ def make_ppo_modules_pixels(proof_environment, device): strides=[4, 2, 1], device=device, ) - common_cnn_output = common_cnn(torch.ones(input_shape)) + common_cnn_output = common_cnn(torch.ones(input_shape, device=device)) common_mlp = MLP( in_features=common_cnn_output.shape[-1], activation_class=torch.nn.ReLU,