diff --git a/src/behavior/mlp.py b/src/behavior/mlp.py
index 0a89ad51..0b8df382 100644
--- a/src/behavior/mlp.py
+++ b/src/behavior/mlp.py
@@ -113,7 +113,11 @@ def get_action_and_value(self, nobs: torch.Tensor, action=None):
         )
 
     def load_bc_weights(self, bc_weights_path):
-        wts = torch.load(bc_weights_path)["model_state_dict"]
+
+        wts = torch.load(bc_weights_path)
+
+        if "model_state_dict" in wts:
+            wts = wts["model_state_dict"]
 
         # Filter out keys not starting with "model"
         model_wts = {k: v for k, v in wts.items() if k.startswith("model")}
diff --git a/src/config/base_mlp_ppo.yaml b/src/config/base_mlp_ppo.yaml
index 68113a78..61c37ced 100644
--- a/src/config/base_mlp_ppo.yaml
+++ b/src/config/base_mlp_ppo.yaml
@@ -27,7 +27,8 @@ critic:
   last_layer_std: 0.25
   last_layer_activation: null
 
-init_logstd: -5.0
+init_logstd: -5.0
+learn_std: true
 
 total_timesteps: 500_000_000
 num_envs: 1024
diff --git a/src/train/ppo.py b/src/train/ppo.py
index a662339e..fce8c6e6 100644
--- a/src/train/ppo.py
+++ b/src/train/ppo.py
@@ -525,6 +525,7 @@ def main(cfg: DictConfig):
                 ref_dist = Normal(action_mean, action_std)
                 kl_loss = -ref_dist.log_prob(mb_new_actions).mean()
                 policy_loss = policy_loss + cfg.kl_coef * kl_loss
+                # Total loss
                 loss = policy_loss + cfg.vf_coef * v_loss
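
A minimal sketch of the checkpoint handling introduced in the mlp.py hunk above, assuming checkpoints come in two formats: a raw state_dict, or a dict that wraps the weights under "model_state_dict". The function name extract_model_weights and the example paths are hypothetical; only the loading and key-filtering logic mirrors the diff.

```python
import torch


def extract_model_weights(bc_weights_path):
    # Load whatever was saved; this may be a raw state_dict or a full
    # checkpoint dict with metadata (optimizer state, step count, etc.).
    wts = torch.load(bc_weights_path)

    # Checkpoints that wrap the weights keep them under "model_state_dict";
    # plain state_dicts are used as-is.
    if "model_state_dict" in wts:
        wts = wts["model_state_dict"]

    # Keep only the keys belonging to the "model" submodule, as in
    # load_bc_weights.
    return {k: v for k, v in wts.items() if k.startswith("model")}


# Hypothetical usage: both checkpoint formats yield the same filtered dict.
# model_wts = extract_model_weights("checkpoints/bc_full_checkpoint.pt")
# model_wts = extract_model_weights("checkpoints/bc_state_dict_only.pt")
```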