Skip to content

Commit

Permalink
amend
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 12, 2024
1 parent f095c01 commit c8de01b
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8863,6 +8863,9 @@ def test_iql_batcher(
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority in ms_td.keys()

# Attach a target-network updater so the loss no longer warns that none has been set
SoftUpdate(loss_fn, eps=0.5)

with torch.no_grad():
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
Expand Down Expand Up @@ -9286,6 +9289,7 @@ def test_discrete_iql(
temperature=temperature,
expectile=expectile,
loss_function="l2",
action_space="one-hot",
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -9408,6 +9412,7 @@ def test_discrete_iql_state_dict(
temperature=temperature,
expectile=expectile,
loss_function="l2",
action_space="one-hot",
)
sd = loss_fn.state_dict()
loss_fn2 = DiscreteIQLLoss(
Expand All @@ -9418,6 +9423,7 @@ def test_discrete_iql_state_dict(
temperature=temperature,
expectile=expectile,
loss_function="l2",
action_space="one-hot",
)
loss_fn2.load_state_dict(sd)

Expand All @@ -9431,6 +9437,7 @@ def test_discrete_iql_separate_losses(self, separate_losses):
value_network=value,
loss_function="l2",
separate_losses=separate_losses,
action_space="one-hot",
)
with pytest.warns(UserWarning, match="No target network updater has been"):
loss = loss_fn(td)
Expand Down Expand Up @@ -9609,6 +9616,7 @@ def test_discrete_iql_batcher(
temperature=temperature,
expectile=expectile,
loss_function="l2",
action_space="one-hot",
)

ms = MultiStep(gamma=gamma, n_steps=n).to(device)
Expand All @@ -9624,6 +9632,8 @@ def test_discrete_iql_batcher(
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority in ms_td.keys()

SoftUpdate(loss_fn, eps=0.5)

with torch.no_grad():
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
Expand Down Expand Up @@ -9695,6 +9705,7 @@ def test_discrete_iql_tensordict_keys(self, td_est):
qvalue_network=qvalue,
value_network=value,
loss_function="l2",
action_space="one-hot",
)

default_keys = {
Expand All @@ -9720,6 +9731,7 @@ def test_discrete_iql_tensordict_keys(self, td_est):
qvalue_network=qvalue,
value_network=value,
loss_function="l2",
action_space="one-hot",
)

key_mapping = {
Expand Down Expand Up @@ -9755,7 +9767,10 @@ def test_discrete_iql_notensordict(
value = self._create_mock_value(observation_key=observation_key)

loss = DiscreteIQLLoss(
actor_network=actor, qvalue_network=qvalue, value_network=value
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
action_space="one-hot",
)
loss.set_keys(
action=action_key,
Expand Down

0 comments on commit c8de01b

Please sign in to comment.