Adding support for resetting the random state to make precomputed not ignore errors
Federico-PizarroBejarano committed Nov 19, 2024
1 parent 21f621a commit 43f0f44
Showing 5 changed files with 53 additions and 16 deletions.
@@ -32,3 +32,5 @@ algo_config:
  penalize_sf_diff: True
  sf_penalty: 75
  use_safe_reset: True
+  decay_factor_curriculum: False
+  preserve_random_state: False
9 changes: 7 additions & 2 deletions experiments/mpsc/train_all_models.sh
@@ -2,7 +2,12 @@
sbatch train_model.sbatch False 1 1
for MPSC_COST_HORIZON in 2 5 10 20; do
    for DECAY_FACTOR in 0.25 0.5 0.75 1; do
-        sbatch train_model.sbatch True $MPSC_COST_HORIZON $DECAY_FACTOR
-        sbatch train_model.sbatch False $MPSC_COST_HORIZON $DECAY_FACTOR
+        # Ignore precomputed differences
+        sbatch train_model.sbatch False $MPSC_COST_HORIZON $DECAY_FACTOR False
+        sbatch train_model.sbatch True $MPSC_COST_HORIZON $DECAY_FACTOR False
+
+        # Preserve random state
+        sbatch train_model.sbatch False $MPSC_COST_HORIZON $DECAY_FACTOR True
+        sbatch train_model.sbatch True $MPSC_COST_HORIZON $DECAY_FACTOR True
    done
done
5 changes: 3 additions & 2 deletions experiments/mpsc/train_model.sbatch
@@ -20,7 +20,7 @@ MPSC_COST_HORIZON=$2
DECAY_FACTOR=$3
SF_PENALTY=1.0

TAG="curriculum_$1_$2_$3"
TAG="curriculum_$1_$2_$3_$4"
echo $TAG $SYS $ALGO $TASK

# Train the unsafe controller/agent.
@@ -45,6 +45,7 @@ python3 train_rl.py \
    algo_config.use_safe_reset=True \
    algo_config.penalize_sf_diff=True \
    algo_config.sf_penalty=$SF_PENALTY \
-    algo_config.decay_factor_curriculum=$1
+    algo_config.decay_factor_curriculum=$1 \
+    algo_config.preserve_random_state=$4

./mpsc_experiment.sh $TAG $SYS $TASK $ALGO
32 changes: 25 additions & 7 deletions safe_control_gym/controllers/ppo/ppo.py
@@ -26,7 +26,7 @@
from safe_control_gym.math_and_models.normalization import (BaseNormalizer, MeanStdNormalizer,
                                                             RewardStdNormalizer)
from safe_control_gym.utils.logging import ExperimentLogger
-from safe_control_gym.utils.utils import get_random_state, is_wrapped, set_random_state
+from safe_control_gym.utils.utils import is_wrapped


class PPO(BaseController):
@@ -120,8 +120,15 @@ def close(self):

    def save(self,
             path,
+             save_only_random_seed=False,
             ):
        '''Saves model params and experiment state to checkpoint path.'''
+        if save_only_random_seed is True:
+            exp_state = {
+                'env_random_state': self.env.get_env_random_state()
+            }
+            torch.save(exp_state, path)
+            return
        path_dir = os.path.dirname(path)
        os.makedirs(path_dir, exist_ok=True)
        state_dict = {
@@ -133,17 +140,20 @@ def save(self,
        exp_state = {
            'total_steps': self.total_steps,
            'obs': self.obs,
-            'random_state': get_random_state(),
            'env_random_state': self.env.get_env_random_state()
        }
        state_dict.update(exp_state)
        torch.save(state_dict, path)

    def load(self,
             path,
+             load_only_random_seed=False,
             ):
        '''Restores model and experiment given checkpoint path.'''
        state = torch.load(path)
+        if load_only_random_seed is True:
+            self.env.set_env_random_state(state['env_random_state'])
+            return
        # Restore policy.
        self.agent.load_state_dict(state['agent'])
        self.obs_normalizer.load_state_dict(state['obs_normalizer'])
@@ -152,7 +162,6 @@ def load(self,
        if self.training:
            self.total_steps = state['total_steps']
            self.obs = state['obs']
-            set_random_state(state['random_state'])
            self.env.set_env_random_state(state['env_random_state'])
        self.logger.load(self.total_steps)
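The new save_only_random_seed / load_only_random_seed switches let a caller checkpoint and restore just the environment's RNG state instead of the full experiment. Below is a minimal standalone sketch of that pattern; ToyEnv and its methods are illustrative stand-ins for the environment's get_env_random_state / set_env_random_state API, not code from this repository.

import numpy as np


class ToyEnv:
    '''Stand-in environment; only the RNG plumbing is modelled.'''

    def __init__(self, seed=0):
        self.np_random = np.random.default_rng(seed)

    def get_env_random_state(self):
        # Return a snapshot of the generator state (a plain dict).
        return self.np_random.bit_generator.state

    def set_env_random_state(self, state):
        self.np_random.bit_generator.state = state

    def step(self):
        # A step that consumes randomness (e.g. process noise).
        return self.np_random.normal()


env = ToyEnv()
snapshot = env.get_env_random_state()        # like save(path, save_only_random_seed=True)
lookahead = [env.step() for _ in range(3)]   # filter rolls the environment forward
env.set_env_random_state(snapshot)           # like load(path, load_only_random_seed=True)
replay = [env.step() for _ in range(3)]
assert lookahead == replay                   # the training rollout sees the same noise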

@@ -192,7 +201,7 @@ def learn(self,
            if self.log_interval and self.total_steps % self.log_interval == 0:
                self.log_step(results)

-    def select_action(self, obs, info=None):
+    def select_action(self, obs, info=None, training=False):
        '''Determine the action to take at the current timestep.

        Args:
@@ -203,9 +212,14 @@ def select_action(self, obs, info=None):
            action (ndarray): The action chosen by the controller.
        '''

-        with torch.no_grad():
-            obs = torch.FloatTensor(obs).to(self.device)
-            action = self.agent.ac.act(obs)
+        if not training:
+            with torch.no_grad():
+                obs = torch.FloatTensor(obs).to(self.device)
+                action = self.agent.ac.act(obs)
+        else:
+            with torch.no_grad():
+                obs = torch.FloatTensor(obs).to(self.device)
+                action, _, _ = self.agent.ac.step(obs)

        return action
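The new training flag switches select_action between the evaluation path (ac.act) and the sampled path (ac.step) that train_step itself uses, so the safety filter's lookahead can reproduce the same kind of action as the rollout. The toy contrast below assumes act() returns the distribution mean and step() samples from the action distribution; that matches the usual actor-critic layout but is an assumption here, not code from this repository.

import torch
from torch.distributions import Normal

torch.manual_seed(0)
mean, std = torch.zeros(2), torch.ones(2)
dist = Normal(mean, std)

deterministic_action = mean                     # what an act()-style call would return
sampled_action = dist.sample()                  # what a step()-style call would return
log_prob = dist.log_prob(sampled_action).sum()  # step() also returns value/log-prob terms

print(deterministic_action, sampled_action, log_prob.item())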

@@ -216,6 +230,7 @@ def run(self,
            verbose=False,
            ):
        '''Runs evaluation with current policy.'''
+        self.curr_training = False
        self.agent.eval()
        self.obs_normalizer.set_read_only()
        if env is None:
@@ -283,13 +298,16 @@ def run(self,

    def train_step(self):
        '''Performs a training/fine-tuning step.'''
+        self.curr_training = True
        self.agent.train()
        self.obs_normalizer.unset_read_only()
        rollouts = PPOBuffer(self.env.observation_space, self.env.action_space, self.rollout_steps, self.rollout_batch_size)
        obs = self.obs
        true_obs = self.true_obs
        info = self.info
        start = time.time()
+        if self.safety_filter is not None and self.preserve_random_state is True:
+            self.save('./temp-data/saved_controller_prev.npy', save_only_random_seed=True)
        for _ in range(self.rollout_steps):
            with torch.no_grad():
                action, v, logp = self.agent.ac.step(torch.FloatTensor(obs).to(self.device))
21 changes: 16 additions & 5 deletions safe_control_gym/safety_filters/mpsc/mpsc_cost_function/precomputed_cost.py
@@ -3,6 +3,7 @@
import numpy as np

from safe_control_gym.controllers.pid.pid import PID
+from safe_control_gym.controllers.ppo.ppo import PPO
from safe_control_gym.envs.benchmark_env import Environment
from safe_control_gym.envs.env_wrappers.vectorized_env.vec_env import VecEnv
from safe_control_gym.safety_filters.mpsc.mpsc_cost_function.abstract_cost import MPSC_COST
@@ -99,15 +100,19 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
        if isinstance(self.uncertified_controller, PID):
            self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy')
            self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
+        elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state:
+            self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_curr.npy', save_only_random_seed=True)
+            self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_prev.npy', load_only_random_seed=True)

        for h in range(self.mpsc_cost_horizon):
            next_step = min(iteration + h, self.env.X_GOAL.shape[0] - 1)
            # Concatenate goal info (goal state(s)) for RL
            extended_obs = self.env.extend_obs(obs, next_step + 1)

-            info = {'current_step': next_step}
-
-            action = self.uncertified_controller.select_action(obs=extended_obs, info=info)
+            if isinstance(self.uncertified_controller, PPO):
+                action = self.uncertified_controller.select_action(obs=extended_obs, info={'current_step': next_step}, training=self.uncertified_controller.curr_training)
+            else:
+                action = self.uncertified_controller.select_action(obs=extended_obs, info={'current_step': next_step})

            if uncert_env.NORMALIZED_RL_ACTION_SPACE:
                if self.env.NAME == Environment.CARTPOLE:
@@ -117,8 +122,11 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):

            action = np.clip(action, self.env.physical_action_bounds[0], self.env.physical_action_bounds[1])

-            # if h == 0 and np.linalg.norm(uncertified_action - action) >= 0.001:
-            #     raise ValueError(f'[ERROR] Mismatch between unsafe controller and MPSC guess. Uncert: {uncertified_action}, Guess: {action}, Diff: {np.linalg.norm(uncertified_action - action)}.')
+            if h == 0 \
+                    and np.linalg.norm(uncertified_action - action) >= 0.001 \
+                    and np.linalg.norm(uncertified_action - uncert_env.hover_thrust * np.ones(uncertified_action.shape)) >= 0.001 \
+                    and self.uncertified_controller.preserve_random_state is True:
+                raise ValueError(f'[ERROR] Mismatch between unsafe controller and MPSC guess. Uncert: {uncertified_action}, Guess: {action}, Diff: {np.linalg.norm(uncertified_action - action)}.')

            v_L[:, h:h + 1] = action.reshape((self.model.nu, 1))

@@ -127,5 +135,8 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
        if isinstance(self.uncertified_controller, PID):
            self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy')
            self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy')
+        elif isinstance(self.uncertified_controller, PPO) and self.uncertified_controller.curr_training is True and self.uncertified_controller.preserve_random_state is True:
+            self.uncertified_controller.load(f'{self.output_dir}/temp-data/saved_controller_curr.npy', load_only_random_seed=True)
+            self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy', save_only_random_seed=True)

        return v_L
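The previously commented-out consistency check is re-enabled here (the "not ignore errors" part of the commit): when the random state is preserved, the action the filter recomputes at h == 0 must match the uncertified action, with an exception when the uncertified action is the hover-thrust action. A standalone approximation of that check is sketched below; the function wrapper and the hover_thrust value are made up for illustration and are not part of the diff.

import numpy as np

def check_mpsc_guess(uncertified_action, guessed_action, hover_thrust=0.06, tol=1e-3):
    '''Raise if the filter's h == 0 guess disagrees with the uncertified action.'''
    hover = hover_thrust * np.ones_like(uncertified_action)
    mismatch = np.linalg.norm(uncertified_action - guessed_action) >= tol
    not_hover = np.linalg.norm(uncertified_action - hover) >= tol
    if mismatch and not_hover:
        raise ValueError(f'Mismatch between unsafe controller and MPSC guess: {uncertified_action} vs {guessed_action}.')

check_mpsc_guess(np.array([0.058]), np.array([0.058]))  # passes: identical actions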
