Skip to content

Commit

Permalink
TS bug fix
Browse files Browse the repository at this point in the history
Summary:
TS bug fix: numerical errors leads to deltaA to be non symmetric (order small <1e-8). When summing this it may result with a non-symmetric matrix A which leads to an error in the type check: when sampling from a Gaussian distribution the covariance matrix is required to be both PD and symmetric.

The offered solution is to explicitly symmetrize it, hence guarantee deltaA is symmetric.

Reviewed By: danielrjiang

Differential Revision: D56936024

fbshipit-source-id: 8219ca1e1a1577b0a241c0a2d7848ab02f427c2c
  • Loading branch information
Yonathan Efroni authored and facebook-github-bot committed May 3, 2024
1 parent d1d4dad commit 84b8772
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def learn_batch(
x, y, weight = self._validate_train_inputs(x, y, weight)

delta_A = torch.matmul(x.t(), x * weight)
delta_A = (delta_A + delta_A.t()) / 2 # symmetrize to avoid numerical errors
delta_b = torch.matmul(x.t(), y * weight).squeeze(-1)
delta_sum_weight = weight.sum()

Expand Down
14 changes: 6 additions & 8 deletions test/unit/test_tutorials/test_cb_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@

set_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_id = 0 if torch.cuda.is_available() else -1

"""
This is a unit test version of the CB tutorial.
Expand All @@ -56,7 +54,7 @@ def setUp(self) -> None:

def test_cb_tutorials(self) -> None:
# load environment
device = -1
device_id = 0 if torch.cuda.is_available() else -1

# Download UCI dataset if doesn't exist
uci_data_path = "./utils/instantiations/environments/uci_datasets"
Expand All @@ -75,8 +73,8 @@ def test_cb_tutorials(self) -> None:
env = SLCBEnvironment(**pendigits_uci_dict) # pyre-ignore

# experiment code
number_of_steps = 200
record_period = 400
number_of_steps = 300
record_period = 300

"""
SquareCB
Expand All @@ -98,7 +96,7 @@ def test_cb_tutorials(self) -> None:
),
),
replay_buffer=FIFOOffPolicyReplayBuffer(100_000),
device_id=device,
device_id=device_id,
)

_ = online_learning(
Expand Down Expand Up @@ -126,7 +124,7 @@ def test_cb_tutorials(self) -> None:
exploration_module=UCBExploration(alpha=1.0),
),
replay_buffer=FIFOOffPolicyReplayBuffer(100_000),
device_id=device,
device_id=device_id,
)

_ = online_learning(
Expand Down Expand Up @@ -155,7 +153,7 @@ def test_cb_tutorials(self) -> None:
exploration_module=ThompsonSamplingExplorationLinear(),
),
replay_buffer=FIFOOffPolicyReplayBuffer(100_000),
device_id=-1,
device_id=device_id,
)

_ = online_learning(
Expand Down

0 comments on commit 84b8772

Please sign in to comment.