diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index 9932c8ba8b7..629d83a6dd3 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -50,13 +50,13 @@ ) # Anything from 2.5, incl. nightlies, allows for fullgraph -@pytest.fixture(scope="module", autouse=True) -def set_default_device(): - cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - torch.set_default_device(device) - yield - torch.set_default_device(cur_device) +# @pytest.fixture(scope="module", autouse=True) +# def set_default_device(): +# cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# torch.set_default_device(device) +# yield +# torch.set_default_device(cur_device) class setup_value_fn: @@ -173,7 +173,14 @@ def test_dqn_speed( ): if compile: torch._dynamo.reset_code_caches() - net = MLP(in_features=n_obs, out_features=n_act, depth=depth, num_cells=ncells) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + net = MLP( + in_features=n_obs, + out_features=n_act, + depth=depth, + num_cells=ncells, + device=device, + ) action_space = "one-hot" mod = QValueActor(net, in_keys=["obs"], action_space=action_space) loss = DQNLoss(value_network=mod, action_space=action_space) @@ -188,6 +195,7 @@ def test_dqn_speed( }, }, [batch], + device=device, ) loss(td) @@ -220,23 +228,27 @@ def test_ddpg_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -251,6 +263,7 @@ def test_ddpg_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor_head = Mod(actor, in_keys=["hidden"], out_keys=["action"]) @@ -291,23 +304,27 @@ def test_sac_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -322,6 +339,7 @@ def test_sac_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -374,23 +392,27 @@ def test_redq_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -405,6 +427,7 @@ def test_redq_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -460,23 +483,27 @@ def test_redq_deprec_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -491,6 +518,7 @@ def test_redq_deprec_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -544,23 +572,27 @@ def test_td3_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -575,6 +607,7 @@ def test_td3_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -633,23 +666,27 @@ def test_cql_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -664,6 +701,7 @@ def test_cql_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -724,23 +762,27 @@ def test_a2c_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common_net = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value_net = MLP( in_features=n_hidden, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch, T] td = TensorDict( @@ -757,6 +799,7 @@ def test_a2c_speed( }, batch, names=[None, "time"], + device=device, ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -775,7 +818,9 @@ def test_a2c_speed( critic(td.clone()) loss = A2CLoss(actor_network=actor, critic_network=critic) - advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) + advantage = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device + ) advantage(td) loss(td) @@ -816,23 +861,27 @@ def test_ppo_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common_net = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value_net = MLP( in_features=n_hidden, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch, T] td = TensorDict( @@ -849,6 +898,7 @@ def test_ppo_speed( }, batch, names=[None, "time"], + device=device, ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -867,7 +917,9 @@ def test_ppo_speed( critic(td.clone()) loss = ClipPPOLoss(actor_network=actor, critic_network=critic) - advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) + advantage = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device + ) advantage(td) loss(td) @@ -908,23 +960,27 @@ def test_reinforce_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common_net = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value_net = MLP( in_features=n_hidden, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch, T] td = TensorDict( @@ -941,6 +997,7 @@ def test_reinforce_speed( }, batch, names=[None, "time"], + device=device, ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -959,7 +1016,9 @@ def test_reinforce_speed( critic(td.clone()) loss = ReinforceLoss(actor_network=actor, critic_network=critic) - advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) + advantage = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device + ) advantage(td) loss(td) @@ -1000,29 +1059,34 @@ def test_iql_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common_net = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value_net = MLP( in_features=n_hidden, num_cells=ncells, depth=2, out_features=1, + device=device, ) qvalue_net = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch, T] td = TensorDict( @@ -1039,6 +1103,7 @@ def test_iql_speed( }, batch, names=[None, "time"], + device=device, ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -1087,4 +1152,4 @@ def loss_and_bw(td): if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main([__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + unknown)