Skip to content

Commit

Permalink
Small clean-up and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pizarrob committed Nov 16, 2023
1 parent 49c4ad5 commit 7cf0d2e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 22 deletions.
4 changes: 2 additions & 2 deletions safe_control_gym/controllers/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,12 @@ def run(self,
physical_action = env.denormalize_action(action)
unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
if success and self.filter_train_actions is True:
if success:
action = env.normalize_action(certified_action)
else:
self.safety_filter.ocp_solver.reset()
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
if success and self.filter_train_actions is True:
if success:
action = self.env.envs[0].normalize_action(certified_action)

action = np.atleast_2d(np.squeeze([action]))
Expand Down
4 changes: 2 additions & 2 deletions safe_control_gym/controllers/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,12 @@ def run(self, env=None, render=False, n_episodes=10, verbose=False, **kwargs):
physical_action = env.denormalize_action(action)
unextended_obs = np.squeeze(true_obs)[:env.symbolic.nx]
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
if success and self.filter_train_actions is True:
if success:
applied_action = env.normalize_action(certified_action)
else:
self.safety_filter.ocp_solver.reset()
certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info)
if success and self.filter_train_actions is True:
if success:
applied_action = self.env.envs[0].normalize_action(certified_action)

action = np.atleast_2d(np.squeeze([applied_action]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(self,

self.output_dir = output_dir
self.uncertified_controller = None
self.skip_checks = False

def get_cost(self, opti_dict):
'''Returns the cost function for the MPSC optimization in symbolic form.
Expand Down Expand Up @@ -106,10 +105,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
# Concatenate goal info (goal state(s)) for RL
extended_obs = self.env.extend_obs(obs, next_step + 1)

info = {
'current_step': next_step,
'constraint_values': np.concatenate([self.get_constraint_value(con, obs) for con in self.env.constraints.state_constraints])
}
info = {'current_step': next_step}

action = self.uncertified_controller.select_action(obs=extended_obs, info=info)

Expand All @@ -121,7 +117,7 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):

action = np.clip(action, self.env.physical_action_bounds[0], self.env.physical_action_bounds[1])

if h == 0 and np.linalg.norm(uncertified_action - action) >= 0.001 and not self.skip_checks:
if h == 0 and np.linalg.norm(uncertified_action - action) >= 0.001:
raise ValueError(f'[ERROR] Mismatch between unsafe controller and MPSC guess. Uncert: {uncertified_action}, Guess: {action}, Diff: {np.linalg.norm(uncertified_action - action)}.')

v_L[:, h:h + 1] = action.reshape((self.model.nu, 1))
Expand All @@ -133,15 +129,3 @@ def calculate_unsafe_path(self, obs, uncertified_action, iteration):
self.uncertified_controller.save(f'{self.output_dir}/temp-data/saved_controller_prev.npy')

return v_L

def get_constraint_value(self, con, state):
'''Gets the value of a constraint given the state.
Args:
con (Constraint): The constraint.
state (ndarray): The state to be tested.
Returns:
value (float): The value of the constraint at the given state.
'''
return np.round(np.atleast_1d(np.squeeze(con.sym_func(np.array(state, ndmin=1)))), decimals=con.decimals)

0 comments on commit 7cf0d2e

Please sign in to comment.