Skip to content

Commit

Permalink
Replace asserts with torch.testing asserts in test_linear_bandits.py
Browse files Browse the repository at this point in the history
Summary: Regular asserts don't print an informative message when they fail. Replace them with `torch.testing` asserts, which print useful debugging information on failure.

Differential Revision: D56588249
  • Loading branch information
Alex Nikulkov authored and facebook-github-bot committed Apr 25, 2024
1 parent e38f01b commit 316e9e4
Showing 1 changed file with 18 additions and 22 deletions.
40 changes: 18 additions & 22 deletions test/unit/with_pytorch/test_linear_bandits.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import unittest

import torch
import torch.testing as tt
from pearl.neural_networks.contextual_bandit.linear_regression import LinearRegression
from pearl.policy_learners.contextual_bandits.linear_bandit import LinearBandit
from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import ( # noqa: E501
Expand Down Expand Up @@ -52,24 +53,20 @@ def setUp(self) -> None:
def test_learn(self) -> None:
batch = self.batch
# a single input
self.assertTrue(
torch.allclose(
self.policy_learner.model(
torch.cat([batch.state[0], batch.action[0]]).unsqueeze(0),
),
batch.reward[0:1],
atol=1e-4,
)
tt.assert_close(
self.policy_learner.model(
torch.cat([batch.state[0], batch.action[0]]).unsqueeze(0),
),
batch.reward[0:1],
atol=1e-3,
rtol=0.0,
)
# a batch input
self.assertTrue(
torch.allclose(
self.policy_learner.model(
torch.cat([batch.state, batch.action], dim=1)
),
batch.reward,
atol=1e-4,
)
tt.assert_close(
self.policy_learner.model(torch.cat([batch.state, batch.action], dim=1)),
batch.reward,
atol=1e-3,
rtol=0.0,
)

def test_linear_ucb_scores(self) -> None:
Expand Down Expand Up @@ -167,12 +164,11 @@ def test_linear_ucb_sigma(self) -> None:

# the 2nd arm's sigma is sqrt(10) times the 1st arm's sigma
sigma_ratio = (sigma[-1] / sigma[0]).clone().detach()
self.assertTrue(
torch.allclose(
sigma_ratio,
torch.tensor(10.0**0.5), # the 1st arm occurred 10 times as often as the 2nd arm
rtol=0.01,
)
tt.assert_close(
sigma_ratio,
torch.tensor(10.0**0.5), # the 1st arm occurred 10 times as often as the 2nd arm
rtol=0.01,
atol=0.0,
)

def test_linear_thompson_sampling_act(self) -> None:
Expand Down

0 comments on commit 316e9e4

Please sign in to comment.