Removing set_cost_function_param and setting it via rew_state_weight and rew_act_weight
Federico-PizarroBejarano committed Dec 2, 2024
1 parent 253f25c commit 44c38bb
Showing 20 changed files with 39 additions and 132 deletions.
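The pattern shared by the diffs below: controllers no longer push their Q and R into the environment through set_cost_function_param; instead each task config declares rew_state_weight and rew_act_weight, and the environment builds its quadratic-cost matrices from them. A minimal sketch of that construction, using the cartpole weights from the first config below (the cost function here is a simplified stand-in for illustration, not the library's exact reward code):

import numpy as np

# Weights as declared in the task config (cartpole LQR stabilization values below).
rew_state_weight = [1, 1, 1, 1]   # one entry per state dimension
rew_act_weight = [0.1]            # one entry per action dimension

# The environment now builds its cost matrices directly from these weights,
# mirroring the additions to cartpole.py and quadrotor.py in this commit.
Q = np.diag(np.array(rew_state_weight, ndmin=1, dtype=float))
R = np.diag(np.array(rew_act_weight, ndmin=1, dtype=float))

def quadratic_cost(state, action, state_goal, action_goal):
    # Quadratic penalty on state and action error (illustrative stand-in).
    ds = state - state_goal
    da = action - action_goal
    return float(ds @ Q @ ds + da @ R @ da)

# Example: 0.2 m position error with a 0.5 N input gives 0.04 + 0.025 = 0.065.
print(quadratic_cost(np.array([0.2, 0, 0, 0]), np.array([0.5]),
                     np.zeros(4), np.zeros(1)))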
@@ -30,4 +30,6 @@ task_config:

episode_len_sec: 6
cost: quadratic
rew_state_weight: [1, 1, 1, 1] # Match LQR weights
rew_act_weight: [0.1]
done_on_out_of_bound: True
2 changes: 2 additions & 0 deletions examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml
@@ -33,4 +33,6 @@ task_config:

episode_len_sec: 6
cost: quadratic
rew_state_weight: [1, 0.1, 0.1, 0.1] # Match LQR weights
rew_act_weight: [0.1]
done_on_out_of_bound: True
@@ -42,4 +42,6 @@ task_config:

episode_len_sec: 6
cost: quadratic
rew_state_weight: [1, 1, 1, 1, 1, 1] # Match LQR weights
rew_act_weight: [0.1]
done_on_out_of_bound: True
@@ -45,4 +45,6 @@ task_config:

episode_len_sec: 6
cost: quadratic
rew_state_weight: [1, 0.1, 1, 0.1, 0.1, 0.1] # Match LQR weights
rew_act_weight: [0.1]
done_on_out_of_bound: True
@@ -68,4 +68,7 @@ task_config:

episode_len_sec: 6
cost: quadratic
# Match LQR weights
rew_state_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
rew_act_weight: [0.1]
done_on_out_of_bound: True
@@ -71,4 +71,7 @@ task_config:

episode_len_sec: 6
cost: quadratic
# Match LQR weights
rew_state_weight: [1, 0.1, 1, 0.1, 1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
rew_act_weight: [0.1]
done_on_out_of_bound: True
@@ -30,6 +30,8 @@ task_config:

episode_len_sec: 6
cost: quadratic
rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights
rew_act_weight: [0.1]
done_on_out_of_bound: True

constraints:
2 changes: 2 additions & 0 deletions examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml
@@ -33,6 +33,8 @@ task_config:

episode_len_sec: 6
cost: quadratic
rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights
rew_act_weight: [0.1]
done_on_out_of_bound: True

constraints:
@@ -5,9 +5,9 @@ algo_config:
 - 0.1
 - 0.1
 q_mpc:
-- 1.0
+- 5.0
 - 0.1
-- 1.0
+- 5.0
 - 0.1
 - 0.1
 - 0.1
@@ -42,6 +42,8 @@ task_config:

episode_len_sec: 6
cost: quadratic
rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights
rew_act_weight: [0.1, 0.1]
done_on_out_of_bound: True

constraints:
@@ -45,6 +45,8 @@ task_config:

episode_len_sec: 6
cost: quadratic
rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights
rew_act_weight: [0.1, 0.1]
done_on_out_of_bound: True

constraints:
@@ -7,11 +7,11 @@ algo_config:
 - 0.1
 - 0.1
 q_mpc:
-- 1.0
+- 5.0
 - 0.1
-- 1.0
+- 5.0
 - 0.1
-- 1.0
+- 5.0
 - 0.1
 - 0.1
 - 0.1
@@ -68,6 +68,9 @@ task_config:

episode_len_sec: 6
cost: quadratic
# Match MPC weights
rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
rew_act_weight: [0.1, 0.1, 0.1, 0.1]
done_on_out_of_bound: True
constraints:
- constraint_form: default_constraint
@@ -70,6 +70,9 @@ task_config:
proj_normal: [0, 1, 1]
episode_len_sec: 6
cost: quadratic
# Match MPC weights
rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
rew_act_weight: [0.1, 0.1, 0.1, 0.1]
done_on_out_of_bound: True
constraints:
- constraint_form: default_constraint
1 change: 0 additions & 1 deletion safe_control_gym/controllers/lqr/ilqr.py
@@ -64,7 +64,6 @@ def __init__(
self.model = self.get_prior(self.env)
self.Q = get_cost_weight_matrix(self.q_lqr, self.model.nx)
self.R = get_cost_weight_matrix(self.r_lqr, self.model.nu)
self.env.set_cost_function_param(self.Q, self.R)

self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
self.Q, self.R, self.discrete_dynamics)
1 change: 0 additions & 1 deletion safe_control_gym/controllers/lqr/lqr.py
@@ -33,7 +33,6 @@ def __init__(
self.discrete_dynamics = discrete_dynamics
self.Q = get_cost_weight_matrix(q_lqr, self.model.nx)
self.R = get_cost_weight_matrix(r_lqr, self.model.nu)
self.env.set_cost_function_param(self.Q, self.R)

self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
self.Q, self.R, self.discrete_dynamics)
104 changes: 2 additions & 102 deletions safe_control_gym/controllers/mpc/mpc.py
@@ -8,8 +8,8 @@
 from safe_control_gym.controllers.base_controller import BaseController
 from safe_control_gym.controllers.lqr.lqr_utils import discretize_linear_system
 from safe_control_gym.controllers.mpc.mpc_utils import (compute_discrete_lqr_gain_from_cont_linear_system,
-compute_state_rmse, get_cost_weight_matrix,
-reset_constraints, rk_discrete)
+get_cost_weight_matrix, reset_constraints,
+rk_discrete)
 from safe_control_gym.envs.benchmark_env import Task
 from safe_control_gym.envs.constraints import GENERAL_CONSTRAINTS, create_constraint_list
 from safe_control_gym.utils.utils import timing
@@ -431,106 +431,6 @@ def setup_results_dict(self):
't_wall': []
}

def run(self,
env=None,
render=False,
logging=False,
max_steps=None,
terminate_run_on_done=None
):
'''Runs evaluation with current policy.
Args:
render (bool): if to do real-time rendering.
logging (bool): if to log on terminal.
Returns:
dict: evaluation statisitcs, rendered frames.
'''
if env is None:
env = self.env
if terminate_run_on_done is None:
terminate_run_on_done = self.terminate_run_on_done

self.x_prev = None
self.u_prev = None
if not env.initial_reset:
env.set_cost_function_param(self.Q, self.R)
# obs, info = env.reset()
obs = env.reset()
print('Init State:')
print(obs)
ep_returns, ep_lengths = [], []
frames = []
self.setup_results_dict()
self.results_dict['obs'].append(obs)
self.results_dict['state'].append(env.state)
i = 0
if env.TASK == Task.STABILIZATION:
if max_steps is None:
MAX_STEPS = int(env.CTRL_FREQ * env.EPISODE_LEN_SEC)
else:
MAX_STEPS = max_steps
elif env.TASK == Task.TRAJ_TRACKING:
if max_steps is None:
MAX_STEPS = self.traj.shape[1]
else:
MAX_STEPS = max_steps
else:
raise Exception('Undefined Task')
self.terminate_loop = False
done = False
common_metric = 0
while not (done and terminate_run_on_done) and i < MAX_STEPS and not (self.terminate_loop):
action = self.select_action(obs)
if self.terminate_loop:
print('Infeasible MPC Problem')
break
# Repeat input for more efficient control.
obs, reward, done, info = env.step(action)
self.results_dict['obs'].append(obs)
self.results_dict['reward'].append(reward)
self.results_dict['done'].append(done)
self.results_dict['info'].append(info)
self.results_dict['action'].append(action)
self.results_dict['state'].append(env.state)
self.results_dict['state_mse'].append(info['mse'])
self.results_dict['state_error'].append(env.state - env.X_GOAL[i, :])
common_metric += info['mse']
print(i, '-th step.')
print('action:', action)
print('obs', obs)
print('reward', reward)
print('done', done)
print(info)
print()
if render:
env.render()
frames.append(env.render('rgb_array'))
i += 1
# Collect evaluation results.
ep_lengths = np.asarray(ep_lengths)
ep_returns = np.asarray(ep_returns)
if logging:
msg = '****** Evaluation ******\n'
msg += 'eval_ep_length {:.2f} +/- {:.2f} | eval_ep_return {:.3f} +/- {:.3f}\n'.format(
ep_lengths.mean(), ep_lengths.std(), ep_returns.mean(),
ep_returns.std())
if len(frames) != 0:
self.results_dict['frames'] = frames
self.results_dict['obs'] = np.vstack(self.results_dict['obs'])
self.results_dict['state'] = np.vstack(self.results_dict['state'])
try:
self.results_dict['reward'] = np.vstack(self.results_dict['reward'])
self.results_dict['action'] = np.vstack(self.results_dict['action'])
self.results_dict['full_traj_common_cost'] = common_metric
self.results_dict['total_rmse_state_error'] = compute_state_rmse(self.results_dict['state'])
self.results_dict['total_rmse_obs_error'] = compute_state_rmse(self.results_dict['obs'])
except ValueError:
raise Exception('[ERROR] mpc.run().py: MPC could not find a solution for the first step given the initial conditions. '
'Check to make sure initial conditions are feasible.')
return deepcopy(self.results_dict)

def reset_before_run(self, obs, info=None, env=None):
'''Reinitialize just the controller before a new run.
23 changes: 0 additions & 23 deletions safe_control_gym/envs/benchmark_env.py
@@ -176,10 +176,6 @@ def __init__(self,
self.state_dim = self.state_space.shape[0]
else:
self.state_dim = self.obs_dim
# Default Q and R matrices for quadratic cost.
if self.COST == Cost.QUADRATIC:
self.Q = np.eye(self.observation_space.shape[0])
self.R = np.eye(self.action_space.shape[0])
# Set constraint info.
self.CONSTRAINTS = constraints
self.DONE_ON_VIOLATION = done_on_violation
@@ -221,25 +217,6 @@ def seed(self,
disturbs.seed(self)
return [seed]

def set_cost_function_param(self,
Q,
R
):
'''Set the cost function parameters.
Args:
Q (ndarray): State weight matrix (nx by nx).
R (ndarray): Input weight matrix (nu by nu).
'''

if not self.initial_reset:
self.Q = Q
self.R = R
else:
raise RuntimeError(
'[ERROR] env.set_cost_function_param() cannot be called after the first reset of the environment.'
)

def set_adversary_control(self, action):
'''Sets disturbance by an adversary controller, called before (each) step().
2 changes: 2 additions & 0 deletions safe_control_gym/envs/gym_control/cartpole.py
@@ -150,7 +150,9 @@ def __init__(self,
self.obs_goal_horizon = obs_goal_horizon
self.obs_wrap_angle = obs_wrap_angle
self.rew_state_weight = np.array(rew_state_weight, ndmin=1, dtype=float)
self.Q = np.diag(self.rew_state_weight)
self.rew_act_weight = np.array(rew_act_weight, ndmin=1, dtype=float)
self.R = np.diag(self.rew_act_weight)
self.rew_exponential = rew_exponential
self.done_on_out_of_bound = done_on_out_of_bound
# BenchmarkEnv constructor, called after defining the custom args,
2 changes: 2 additions & 0 deletions safe_control_gym/envs/gym_pybullet_drones/quadrotor.py
@@ -178,7 +178,9 @@ def __init__(self,
self.norm_act_scale = norm_act_scale
self.obs_goal_horizon = obs_goal_horizon
self.rew_state_weight = np.array(rew_state_weight, ndmin=1, dtype=float)
self.Q = np.diag(self.rew_state_weight)
self.rew_act_weight = np.array(rew_act_weight, ndmin=1, dtype=float)
self.R = np.diag(self.rew_act_weight)
self.rew_exponential = rew_exponential
self.done_on_out_of_bound = done_on_out_of_bound
if info_mse_metric_state_weight is None:
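With set_cost_function_param gone, keeping the controller objective and the environment reward consistent is purely a configuration matter: the q_lqr/q_mpc and r_lqr/r_mpc entries in algo_config should match rew_state_weight and rew_act_weight in task_config, which is what the YAML changes above do. A small consistency-check sketch using representative 2D-quadrotor values taken from the overrides above (the dict layout is illustrative, not the library's config API):

import numpy as np

# Representative values based on the quadrotor MPC overrides above.
task_config = {'rew_state_weight': [5.0, 0.1, 5.0, 0.1, 0.1, 0.1],
               'rew_act_weight': [0.1, 0.1]}
algo_config = {'q_mpc': [5.0, 0.1, 5.0, 0.1, 0.1, 0.1],
               'r_mpc': [0.1, 0.1]}

# Environment reward weights and controller cost weights should yield the same
# Q and R, so evaluation and control optimize the same quadratic objective.
assert np.allclose(np.diag(task_config['rew_state_weight']), np.diag(algo_config['q_mpc']))
assert np.allclose(np.diag(task_config['rew_act_weight']), np.diag(algo_config['r_mpc']))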
