Skip to content

Commit

Permalink
[Feature] Add LossModule.reset_parameters_recursive (#2546)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler authored Nov 9, 2024
1 parent 35a7813 commit 218d5bf
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 0 deletions.
217 changes: 217 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ def get_devices():


class LossModuleTestBase:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
assert hasattr(
cls, "test_reset_parameters_recursive"
), "Please add a test_reset_parameters_recursive test for this class"

def _flatten_in_keys(self, in_keys):
return [
in_key if isinstance(in_key, str) else "_".join(list(unravel_keys(in_key)))
Expand Down Expand Up @@ -252,6 +258,42 @@ def set_advantage_keys_through_loss_test(
getattr(test_fn.value_estimator.tensor_keys, advantage_key) == new_key
)

@classmethod
def reset_parameters_recursive_test(cls, loss_fn):
def get_params(loss_fn):
for key, item in loss_fn.__dict__.items():
if isinstance(item, nn.Module):
module_name = key
params_name = f"{module_name}_params"
target_name = f"target_{module_name}_params"
params = loss_fn._modules.get(params_name, None)
target = loss_fn._modules.get(target_name, None)

if params is not None:
yield params_name, params._param_td

else:
for subparam_name, subparam in loss_fn.named_parameters():
if module_name in subparam_name:
yield subparam_name, subparam

if target is not None:
yield target_name, target

old_params = {}

for param_name, param in get_params(loss_fn):
with torch.no_grad():
# Change the parameter to ensure that reset will change it again
param += 1000
old_params[param_name] = param.clone()

loss_fn.reset_parameters_recursive()

for param_name, param in get_params(loss_fn):
old_param = old_params[param_name]
assert (param != old_param).any()


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("vmap_randomness", (None, "different", "same", "error"))
Expand Down Expand Up @@ -494,6 +536,11 @@ def _create_seq_mock_data_dqn(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor(action_spec_type="one_hot")
loss_fn = DQNLoss(actor)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize(
"delay_value,double_dqn", ([False, False], [True, False], [True, True])
)
Expand Down Expand Up @@ -1066,6 +1113,12 @@ def _create_mock_data_dqn(
td.refine_names(None, "time")
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor(action_spec_type="one_hot")
mixer = self._create_mock_mixer()
loss_fn = QMixerLoss(actor, mixer)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -1570,6 +1623,12 @@ def _create_seq_mock_data_ddpg(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = DDPGLoss(actor, value)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)])
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
Expand Down Expand Up @@ -2210,6 +2269,16 @@ def _create_seq_mock_data_td3(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = TD3Loss(
actor,
value,
bounds=(-1, 1),
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2916,6 +2985,16 @@ def _create_seq_mock_data_td3bc(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = TD3BCLoss(
actor,
value,
bounds=(-1, 1),
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
Expand Down Expand Up @@ -3720,6 +3799,20 @@ def _create_seq_mock_data_sac(
)
return td

def test_reset_parameters_recursive(self, version):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
if version == 1:
value = self._create_mock_value()
else:
value = None
loss_fn = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
Expand Down Expand Up @@ -4591,6 +4684,17 @@ def _create_seq_mock_data_sac(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
loss_fn = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
num_actions=actor.spec["action"].space.n,
action_space="one-hot",
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [2])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -5227,6 +5331,15 @@ def _create_seq_mock_data_crossq(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
Expand Down Expand Up @@ -5962,6 +6075,15 @@ def _create_seq_mock_data_redq(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
loss_fn = REDQLoss(
actor_network=actor,
qvalue_network=qvalue,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -6792,6 +6914,15 @@ def _create_seq_mock_data_cql(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
loss_fn = CQLLoss(
actor_network=actor,
qvalue_network=qvalue,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, True))
@pytest.mark.parametrize("max_q_backup", [True, False])
Expand Down Expand Up @@ -7367,6 +7498,13 @@ def _create_seq_mock_data_dcql(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor(
action_spec_type="one_hot",
)
loss_fn = DiscreteCQLLoss(actor)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -7938,6 +8076,13 @@ def _create_seq_mock_data_ppo(

return td

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
def test_reset_parameters_recursive(self, loss_class):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = loss_class(actor, value)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
Expand Down Expand Up @@ -9016,6 +9161,12 @@ def _create_seq_mock_data_a2c(
td["scale"] = scale
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
value = self._create_mock_value()
loss_fn = A2CLoss(actor, value)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -9624,6 +9775,27 @@ def test_a2c_value_clipping(self, clip_value, device, composite_action_dist):
class TestReinforce(LossModuleTestBase):
seed = 0

def test_reset_parameters_recursive(self):
n_obs = 3
n_act = 5
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
module = TensorDictModule(
net, in_keys=["observation"], out_keys=["loc", "scale"]
)
actor_net = ProbabilisticActor(
module,
distribution_class=TanhNormal,
return_log_prob=True,
in_keys=["loc", "scale"],
spec=Unbounded(n_act),
)
loss_fn = ReinforceLoss(
actor_net,
critic_network=value_net,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("gradient_mode", [True, False])
@pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -10323,6 +10495,11 @@ def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13):
value_model(td)
return value_model

def test_reset_parameters_recursive(self, device):
world_model = self._create_world_model_model(10, 5).to(device)
loss_fn = DreamerModelLoss(world_model)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("lambda_kl", [0, 1.0])
@pytest.mark.parametrize("lambda_reco", [0, 1.0])
@pytest.mark.parametrize("lambda_reward", [0, 1.0])
Expand Down Expand Up @@ -10604,6 +10781,11 @@ def _create_seq_mock_data_odt(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
loss_fn = OnlineDTLoss(actor)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("device", get_available_devices())
def test_odt(self, device):
torch.manual_seed(self.seed)
Expand Down Expand Up @@ -10831,6 +11013,11 @@ def _create_seq_mock_data_dt(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
loss_fn = DTLoss(actor)
self.reset_parameters_recursive_test(loss_fn)

def test_dt_tensordict_keys(self):
actor = self._create_mock_actor()
loss_fn = DTLoss(actor)
Expand Down Expand Up @@ -11034,6 +11221,11 @@ def _create_seq_mock_data_gail(
)
return td

def test_reset_parameters_recursive(self):
discriminator = self._create_mock_discriminator()
loss_fn = GAILLoss(discriminator)
self.reset_parameters_recursive_test(loss_fn)

def test_gail_tensordict_keys(self):
discriminator = self._create_mock_discriminator()
loss_fn = GAILLoss(discriminator)
Expand Down Expand Up @@ -11406,6 +11598,17 @@ def _create_seq_mock_data_iql(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
value = self._create_mock_value()
loss_fn = IQLLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
Expand Down Expand Up @@ -12214,6 +12417,18 @@ def _create_seq_mock_data_discrete_iql(
)
return td

def test_reset_parameters_recursive(self):
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
value = self._create_mock_value()
loss_fn = DiscreteIQLLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
action_space="one-hot",
)
self.reset_parameters_recursive_test(loss_fn)

@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
Expand Down Expand Up @@ -12842,6 +13057,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
)
loss = MyLoss(actor_module)

LossModuleTestBase.reset_parameters_recursive_test(loss)

if create_target_params:
SoftUpdate(loss, eps=0.5)

Expand Down
Loading

0 comments on commit 218d5bf

Please sign in to comment.