Skip to content

Commit

Permalink
Switch to using the model name rather than a random ID
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Nov 20, 2024
1 parent 64c5337 commit 86e8345
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 7 deletions.
1 change: 1 addition & 0 deletions experiments/mpsc/mpsc_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def run(plot=True, training=False, n_episodes=1, n_steps=None, curr_path='.', in
if config.algo in ['ppo', 'sac', 'safe_explorer_ppo', 'cpo']:
# Load state_dict from trained.
ctrl.load(f'{curr_path}/models/rl_models/{model}/model_latest.pt')
+ctrl.model_name = model

# Remove temporary files and directories
shutil.rmtree(f'{curr_path}/temp', ignore_errors=True)
Expand Down
2 changes: 1 addition & 1 deletion experiments/mpsc/train_model.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#SBATCH --output=/h/pizarrob/safe-control-gym/experiments/mpsc/temp-data/mpsf_reset_%j.out
#SBATCH --error=/h/pizarrob/safe-control-gym/experiments/mpsc/temp-data/mpsf_reset_%j.err
#SBATCH --partition=cpu
-#SBATCH -t 12:00:00
+#SBATCH -t 24:00:00
#SBATCH --cpus-per-task=1
#SBATCH --mem=16G
#SBATCH --gres=gpu:0
Expand Down
1 change: 1 addition & 0 deletions experiments/mpsc/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def train():

# Setup MPSC.
if config.algo in ['ppo', 'sac']:
+ctrl.model_name = config.output_dir.split('/')[-2]
safety_filter = make(config.safety_filter,
env_func,
**config.sf_config)
Expand Down
3 changes: 1 addition & 2 deletions safe_control_gym/controllers/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def __init__(self,

# Adding safety filter
self.safety_filter = None
-self.instance_idx = int(np.random.rand() * 1000000)

def reset(self):
'''Do initializations for training or evaluation.'''
Expand Down Expand Up @@ -308,7 +307,7 @@ def train_step(self):
info = self.info
start = time.time()
if self.safety_filter is not None and self.preserve_random_state is True:
-self.save(f'./temp-data/saved_controller_prev_{self.instance_idx}.npy', save_only_random_seed=True)
+self.save(f'./temp-data/saved_controller_prev_{self.model_name}.npy', save_only_random_seed=True)
for _ in range(self.rollout_steps):
with torch.no_grad():
action, v, logp = self.agent.ac.step(torch.FloatTensor(obs).to(self.device))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy')
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state:
-self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.instance_idx}.npy', save_only_random_seed=True)
-self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.instance_idx}.npy', load_only_random_seed=True)
+self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.model_name}.npy', save_only_random_seed=True)
+self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.model_name}.npy', load_only_random_seed=True)

for h in range(self.mpsc_cost_horizon):
next_step = min(iteration + h, self.env.X_GOAL.shape[0] - 1)
Expand Down Expand Up @@ -136,7 +136,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy')
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state is True:
-self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.instance_idx}.npy', load_only_random_seed=True)
-self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.instance_idx}.npy', save_only_random_seed=True)
+self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr_{self.uncertified_controller.model_name}.npy', load_only_random_seed=True)
+self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev_{self.uncertified_controller.model_name}.npy', save_only_random_seed=True)

return v_L

0 comments on commit 86e8345

Please sign in to comment.