From 316e9e452e50b9468e67a3483d243631a95cc0e9 Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Thu, 25 Apr 2024 13:42:57 -0700 Subject: [PATCH] Replace asserts with torch.testing asserts in test_linear_bandits.py Summary: Regular asserts don't print an informative message when they fail. Replacing with `torch.testing` asserts, which print out useful information for debugging Differential Revision: D56588249 --- test/unit/with_pytorch/test_linear_bandits.py | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/test/unit/with_pytorch/test_linear_bandits.py b/test/unit/with_pytorch/test_linear_bandits.py index b02af327..957b7840 100644 --- a/test/unit/with_pytorch/test_linear_bandits.py +++ b/test/unit/with_pytorch/test_linear_bandits.py @@ -11,6 +11,7 @@ import unittest import torch +import torch.testing as tt from pearl.neural_networks.contextual_bandit.linear_regression import LinearRegression from pearl.policy_learners.contextual_bandits.linear_bandit import LinearBandit from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import ( # noqa: E501 @@ -52,24 +53,20 @@ def setUp(self) -> None: def test_learn(self) -> None: batch = self.batch # a single input - self.assertTrue( - torch.allclose( - self.policy_learner.model( - torch.cat([batch.state[0], batch.action[0]]).unsqueeze(0), - ), - batch.reward[0:1], - atol=1e-4, - ) + tt.assert_close( + self.policy_learner.model( + torch.cat([batch.state[0], batch.action[0]]).unsqueeze(0), + ), + batch.reward[0:1], + atol=1e-3, + rtol=0.0, ) # a batch input - self.assertTrue( - torch.allclose( - self.policy_learner.model( - torch.cat([batch.state, batch.action], dim=1) - ), - batch.reward, - atol=1e-4, - ) + tt.assert_close( + self.policy_learner.model(torch.cat([batch.state, batch.action], dim=1)), + batch.reward, + atol=1e-3, + rtol=0.0, ) def test_linear_ucb_scores(self) -> None: @@ -167,12 +164,11 @@ def test_linear_ucb_sigma(self) -> None: # the 2nd arm's sigma is sqrt(10) times 1st arm's sigma sigma_ratio = (sigma[-1] / sigma[0]).clone().detach() - self.assertTrue( - torch.allclose( - sigma_ratio, - torch.tensor(10.0**0.5), # the 1st arm occured 10 times than 2nd arm - rtol=0.01, - ) + tt.assert_close( + sigma_ratio, + torch.tensor(10.0**0.5), # the 1st arm occured 10 times than 2nd arm + rtol=0.01, + atol=0.0, ) def test_linear_thompson_sampling_act(self) -> None: