From 86e8345f07e9a2abc5e4f60d36fd82c5418528d1 Mon Sep 17 00:00:00 2001 From: Federico Pizarro Bejarano Date: Wed, 20 Nov 2024 10:00:30 -0500 Subject: [PATCH] Switching to using model name rather than random id --- experiments/mpsc/mpsc_experiment.py | 1 + experiments/mpsc/train_model.sbatch | 2 +- experiments/mpsc/train_rl.py | 1 + safe_control_gym/controllers/ppo/ppo.py | 3 +-- .../mpsc/mpsc_cost_function/precomputed_cost.py | 8 ++++---- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/experiments/mpsc/mpsc_experiment.py b/experiments/mpsc/mpsc_experiment.py index 56f6e3180..9fe6228dd 100644 --- a/experiments/mpsc/mpsc_experiment.py +++ b/experiments/mpsc/mpsc_experiment.py @@ -73,6 +73,7 @@ def run(plot=True, training=False, n_episodes=1, n_steps=None, curr_path='.', in if config.algo in ['ppo', 'sac', 'safe_explorer_ppo', 'cpo']: # Load state_dict from trained. ctrl.load(f'{curr_path}/models/rl_models/{model}/model_latest.pt') + ctrl.model_name = model # Remove temporary files and directories shutil.rmtree(f'{curr_path}/temp', ignore_errors=True) diff --git a/experiments/mpsc/train_model.sbatch b/experiments/mpsc/train_model.sbatch index 64839ae6d..550243b28 100755 --- a/experiments/mpsc/train_model.sbatch +++ b/experiments/mpsc/train_model.sbatch @@ -3,7 +3,7 @@ #SBATCH --output=/h/pizarrob/safe-control-gym/experiments/mpsc/temp-data/mpsf_reset_%j.out #SBATCH --error=/h/pizarrob/safe-control-gym/experiments/mpsc/temp-data/mpsf_reset_%j.err #SBATCH --partition=cpu -#SBATCH -t 12:00:00 +#SBATCH -t 24:00:00 #SBATCH --cpus-per-task=1 #SBATCH --mem=16G #SBATCH --gres=gpu:0 diff --git a/experiments/mpsc/train_rl.py b/experiments/mpsc/train_rl.py index c69f14c87..340d43d2b 100644 --- a/experiments/mpsc/train_rl.py +++ b/experiments/mpsc/train_rl.py @@ -53,6 +53,7 @@ def train(): # Setup MPSC. if config.algo in ['ppo', 'sac']: + ctrl.model_name = config.output_dir.split('/')[-2] safety_filter = make(config.safety_filter, env_func, **config.sf_config) diff --git a/safe_control_gym/controllers/ppo/ppo.py b/safe_control_gym/controllers/ppo/ppo.py index b6b8e6108..3ab94448f 100644 --- a/safe_control_gym/controllers/ppo/ppo.py +++ b/safe_control_gym/controllers/ppo/ppo.py @@ -90,7 +90,6 @@ def __init__(self, # Adding safety filter self.safety_filter = None - self.instance_idx = int(np.random.rand() * 1000000) def reset(self): '''Do initializations for training or evaluation.''' @@ -308,7 +307,7 @@ def train_step(self): info = self.info start = time.time() if self.safety_filter is not None and self.preserve_random_state is True: - self.save(f'./temp-data/saved_controller_prev_{self.instance_idx}.npy', save_only_random_seed=True) + self.save(f'./temp-data/saved_controller_prev_{self.model_name}.npy', save_only_random_seed=True) for _ in range(self.rollout_steps): with torch.no_grad(): action, v, logp = self.agent.ac.step(torch.FloatTensor(obs).to(self.device)) diff --git a/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py b/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py index 22cb4c362..e8ff9a3b9 100644 --- a/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py +++ b/safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py @@ -101,8 +101,8 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration): self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy') self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy') elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state: - self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.instance_idx}.npy', save_only_random_seed=True) - self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.instance_idx}.npy', load_only_random_seed=True) + self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.model_name}.npy', save_only_random_seed=True) + self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.model_name}.npy', load_only_random_seed=True) for h in range(self.mpsc_cost_horizon): next_step = min(iteration + h, self.env.X_GOAL.shape[0] - 1) @@ -136,7 +136,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration): self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy') self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy') elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state is True: - self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.instance_idx}.npy', load_only_random_seed=True) - self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.instance_idx}.npy', save_only_random_seed=True) + self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.model_name}.npy', load_only_random_seed=True) + self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.model_name}.npy', save_only_random_seed=True) return v_L