
Commit

cleanup add tests
BY571 committed Nov 3, 2023
1 parent 60cf8ec commit 97dd57f
Showing 3 changed files with 383 additions and 24 deletions.
2 changes: 1 addition & 1 deletion examples/cql/discrete_cql_online.py
@@ -138,7 +138,7 @@ def main(cfg: "DictConfig"): # noqa: F821
q_loss.backward()
optimizer.step()
q_losses.append(q_loss.item())
cql_loss.append(loss_dict["cql_loss"].item())
cql_loss.append(loss_dict["loss_cql"].item())

# Update target params
target_net_updater.step()
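The one-line fix updates the dictionary key used for logging: DiscreteCQLLoss reports the conservative regularization term under "loss_cql", so the old "cql_loss" lookup would fail at runtime with a KeyError. A minimal sketch of the corrected logging, assuming loss_module is the DiscreteCQLLoss instance and sampled_tensordict the replay-buffer batch used elsewhere in the script (both names are assumptions, not shown in this hunk):

# sketch only: log the CQL regularization term by its actual output key
loss_dict = loss_module(sampled_tensordict)
assert "loss_cql" in loss_dict.keys()          # "cql_loss" is not an output key
cql_loss.append(loss_dict["loss_cql"].item())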
362 changes: 362 additions & 0 deletions test/test_cost.py
@@ -99,6 +99,7 @@
ClipPPOLoss,
CQLLoss,
DDPGLoss,
DiscreteCQLLoss,
DiscreteSACLoss,
DistributionalDQNLoss,
DQNLoss,
@@ -5164,6 +5165,367 @@ def test_cql_batcher(
)


class TestDiscreteCQL(LossModuleTestBase):
seed = 0

def _create_mock_actor(
self,
action_spec_type,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
is_nn_module=False,
action_value_key=None,
):
# Actor
if action_spec_type == "one_hot":
action_spec = OneHotDiscreteTensorSpec(action_dim)
elif action_spec_type == "categorical":
action_spec = DiscreteTensorSpec(action_dim)
# elif action_spec_type == "nd_bounded":
# action_spec = BoundedTensorSpec(
# -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
# )
else:
raise ValueError(f"Unsupported action_spec_type: {action_spec_type}")

module = nn.Linear(obs_dim, action_dim)
if is_nn_module:
return module.to(device)
actor = QValueActor(
spec=CompositeSpec(
{
"action": action_spec,
"action_value"
if action_value_key is None
else action_value_key: None,
"chosen_action_value": None,
},
shape=[],
),
action_space=action_spec_type,
module=module,
action_value_key=action_value_key,
).to(device)
return actor
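# Note: in addition to "action", the QValueActor built above fills the
# "action_value" entry (the raw Q-values from the linear module) and
# "chosen_action_value" (the Q-value of the selected action), matching the
# keys declared in the CompositeSpec.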

def _create_mock_data_dcql(
self,
action_spec_type,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
action_key="action",
action_value_key="action_value",
):
# create a tensordict
obs = torch.randn(batch, obs_dim)
next_obs = torch.randn(batch, obs_dim)

# random Q-values; the greedy action is one-hot encoded by comparing each
# entry against the row-wise maximum
action_value = torch.randn(batch, action_dim)
action = (action_value == action_value.max(-1, True)[0]).to(torch.long)

if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=False)
reward = torch.randn(batch, 1)
done = torch.zeros(batch, 1, dtype=torch.bool)
terminated = torch.zeros(batch, 1, dtype=torch.bool)
td = TensorDict(
batch_size=(batch,),
source={
"observation": obs,
"next": {
"observation": next_obs,
"done": done,
"terminated": terminated,
"reward": reward,
},
action_key: action,
action_value_key: action_value,
},
device=device,
)
return td

def _create_seq_mock_data_dcql(
self,
action_spec_type,
batch=2,
T=4,
obs_dim=3,
action_dim=4,
device="cpu",
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
obs = total_obs[:, :T]
next_obs = total_obs[:, 1:]

action_value = torch.randn(batch, T, action_dim, device=device)
action = (action_value == action_value.max(-1, True)[0]).to(torch.long)

# action_value = action_value.unsqueeze(-1)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
# all-True mask (no padded time steps); stored under ("collector", "mask") below
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=False)
action = action.masked_fill_(~mask, 0.0)
else:
action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
td = TensorDict(
batch_size=(batch, T),
source={
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"next": {
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
"done": done,
"terminated": terminated,
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
"action": action,
"action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
names=[None, "time"],
)
return td

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_dcql(self, delay_value, device, action_spec_type, td_est):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(
action_spec_type=action_spec_type, device=device
)
td = self._create_mock_data_dcql(
action_spec_type=action_spec_type, device=device
)
loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value)
if td_est is ValueEstimators.GAE:
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
if td_est is not None:
loss_fn.make_value_estimator(td_est)
with (
pytest.warns(UserWarning, match="No target network updater has been")
if delay_value
else contextlib.nullcontext()
), _check_td_steady(td):
loss = loss_fn(td)
assert loss_fn.tensor_keys.priority in td.keys()

sum([item for _, item in loss.items()]).backward()
assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

# Check param update effect on targets
target_value = loss_fn.target_value_network_params.clone()
for p in loss_fn.parameters():
if p.requires_grad:
p.data += torch.randn_like(p)
target_value2 = loss_fn.target_value_network_params.clone()
if loss_fn.delay_value:
assert_allclose_td(target_value, target_value2)
else:
assert not (target_value == target_value2).any()

# check that policy is updated after parameter update
parameters = [p.clone() for p in actor.parameters()]
for p in loss_fn.parameters():
p.data += torch.randn_like(p)
assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
def test_dcql_state_dict(self, delay_value, device, action_spec_type):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(
action_spec_type=action_spec_type, device=device
)
loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value)
sd = loss_fn.state_dict()
loss_fn2 = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(
action_spec_type=action_spec_type, device=device
)

td = self._create_seq_mock_data_dcql(
action_spec_type=action_spec_type, device=device
)
loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value)

ms = MultiStep(gamma=gamma, n_steps=n).to(device)
ms_td = ms(td.clone())

with (
pytest.warns(UserWarning, match="No target network updater has been")
if delay_value
else contextlib.nullcontext()
), _check_td_steady(ms_td):
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority in ms_td.keys()

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum([item for _, item in loss.items()])
_loss_ms = sum([item for _, item in loss_ms.items()])
assert (
abs(_loss - _loss_ms) < 1e-3
), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
else:
with pytest.raises(AssertionError):
assert_allclose_td(loss, loss_ms)
sum([item for _, item in loss_ms.items()]).backward()
assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

# Check param update effect on targets
target_value = loss_fn.target_value_network_params.clone()
for p in loss_fn.parameters():
if p.requires_grad:
p.data += torch.randn_like(p)
target_value2 = loss_fn.target_value_network_params.clone()
if loss_fn.delay_value:
assert_allclose_td(target_value, target_value2)
else:
assert not (target_value == target_value2).any()

# check that policy is updated after parameter update
parameters = [p.clone() for p in actor.parameters()]
for p in loss_fn.parameters():
p.data += torch.randn_like(p)
assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))

@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
)
def test_dcql_tensordict_keys(self, td_est):
torch.manual_seed(self.seed)
action_spec_type = "one_hot"
actor = self._create_mock_actor(action_spec_type=action_spec_type)
loss_fn = DiscreteCQLLoss(actor)

default_keys = {
"value_target": "value_target",
"value": "chosen_action_value",
"priority": "td_error",
"action_value": "action_value",
"action": "action",
"reward": "reward",
"done": "done",
"terminated": "terminated",
}

self.tensordict_keys_test(loss_fn, default_keys=default_keys)

loss_fn = DiscreteCQLLoss(actor)
key_mapping = {
"reward": ("reward", "reward_test"),
"done": ("done", ("done", "test")),
"terminated": ("terminated", ("terminated", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

actor = self._create_mock_actor(
action_spec_type=action_spec_type, action_value_key="chosen_action_value_2"
)
loss_fn = DiscreteCQLLoss(actor)
key_mapping = {
"value": ("value", "chosen_action_value_2"),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@pytest.mark.parametrize("action_spec_type", ("categorical", "one_hot"))
@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
)
def test_dcql_tensordict_run(self, action_spec_type, td_est):
torch.manual_seed(self.seed)
tensor_keys = {
"action_value": "action_value_test",
"action": "action_test",
"priority": "priority_test",
}
actor = self._create_mock_actor(
action_spec_type=action_spec_type,
action_value_key=tensor_keys["action_value"],
)
td = self._create_mock_data_dcql(
action_spec_type=action_spec_type,
action_key=tensor_keys["action"],
action_value_key=tensor_keys["action_value"],
)

loss_fn = DiscreteCQLLoss(actor, loss_function="l2")
loss_fn.set_keys(**tensor_keys)

if td_est is not None:
loss_fn.make_value_estimator(td_est)
with _check_td_steady(td):
_ = loss_fn(td)
assert loss_fn.tensor_keys.priority in td.keys()

@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
def test_dcql_notensordict(
self, observation_key, reward_key, done_key, terminated_key
):
n_obs = 3
n_action = 4
action_spec = OneHotDiscreteTensorSpec(n_action)
module = nn.Linear(n_obs, n_action) # a simple value model
actor = QValueActor(
spec=action_spec,
action_space="one_hot",
module=module,
in_keys=[observation_key],
)
loss = DiscreteCQLLoss(actor)
loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key)
# define data
observation = torch.randn(n_obs)
next_observation = torch.randn(n_obs)
action = action_spec.rand()
next_reward = torch.randn(1)
next_done = torch.zeros(1, dtype=torch.bool)
next_terminated = torch.zeros(1, dtype=torch.bool)
kwargs = {
observation_key: observation,
f"next_{observation_key}": next_observation,
f"next_{reward_key}": next_reward,
f"next_{done_key}": next_done,
f"next_{terminated_key}": next_terminated,
"action": action,
}
td = TensorDict(kwargs, []).unflatten_keys("_")
loss_val = loss(**kwargs)

loss_val_td = loss(td)

torch.testing.assert_close(loss_val_td.get(loss.out_keys[0]), loss_val[0])
torch.testing.assert_close(loss_val_td.get(loss.out_keys[1]), loss_val[1])


class TestPPO(LossModuleTestBase):
seed = 0

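The new tests above double as usage documentation for DiscreteCQLLoss. A minimal, self-contained sketch of the API they exercise (hedged: the module, batch size, and variable names are illustrative; the actor and loss are constructed the same way as in test_dcql_notensordict):

import torch
from tensordict import TensorDict
from torch import nn
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueActor
from torchrl.objectives import DiscreteCQLLoss

spec = OneHotDiscreteTensorSpec(4)
qnet = nn.Linear(3, 4)  # toy value network: obs_dim=3, action_dim=4
actor = QValueActor(spec=spec, action_space="one_hot", module=qnet, in_keys=["observation"])
loss = DiscreteCQLLoss(actor, loss_function="l2")

batch = 8
data = TensorDict(
    {
        "observation": torch.randn(batch, 3),
        "action": spec.rand((batch,)),
        "next": {
            "observation": torch.randn(batch, 3),
            "reward": torch.randn(batch, 1),
            "done": torch.zeros(batch, 1, dtype=torch.bool),
            "terminated": torch.zeros(batch, 1, dtype=torch.bool),
        },
    },
    batch_size=[batch],
)
out = loss(data)  # TensorDict holding "loss_cql" among its loss terms
sum(item for _, item in out.items()).backward()  # same reduction the tests use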
