diff --git a/examples/lqr/config_overrides/cartpole/cartpole_stabilization.yaml b/examples/lqr/config_overrides/cartpole/cartpole_stabilization.yaml
index 0cbb7226e..eefdf5dca 100644
--- a/examples/lqr/config_overrides/cartpole/cartpole_stabilization.yaml
+++ b/examples/lqr/config_overrides/cartpole/cartpole_stabilization.yaml
@@ -30,4 +30,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 1, 1, 1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
diff --git a/examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml b/examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml
index 7eefd0fb0..11b956747 100644
--- a/examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml
+++ b/examples/lqr/config_overrides/cartpole/cartpole_tracking.yaml
@@ -33,4 +33,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 0.1, 0.1, 0.1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
diff --git a/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml b/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml
index 134575710..8132fbca0 100644
--- a/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml
+++ b/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml
@@ -42,4 +42,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 1, 1, 1, 1, 1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
diff --git a/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml b/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml
index aa68b6912..8cf000795 100644
--- a/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml
+++ b/examples/lqr/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml
@@ -45,4 +45,6 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [1, 0.1, 1, 0.1, 0.1, 0.1] # Match LQR weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
diff --git a/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml b/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml
index 8ebffccb7..878726f3f 100644
--- a/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml
+++ b/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml
@@ -68,4 +68,7 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match LQR weights
+  rew_state_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
diff --git a/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml b/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml
index 4089e5359..42fab5a12 100644
--- a/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml
+++ b/examples/lqr/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml
@@ -71,4 +71,7 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match LQR weights
+  rew_state_weight: [1, 0.1, 1, 0.1, 1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
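The controller-side `q_lqr`/`r_lqr` values that these overrides mirror are expanded into matrices by `get_cost_weight_matrix` in `lqr_utils`. A minimal sketch of that expansion, assuming the helper broadcasts a length-1 list and uses a full-length list as the diagonal:

```python
import numpy as np

def get_cost_weight_matrix(weights, dim):
    '''Expand a weight list of length 1 or dim into a (dim, dim) diagonal matrix.'''
    if len(weights) == dim:
        return np.diag(weights)        # one weight per state/input
    if len(weights) == 1:
        return np.diag(weights * dim)  # broadcast a single weight
    raise ValueError('Wrong dimension for cost weights.')

# q_lqr defaults to [1] on the 4-state cartpole, i.e. np.eye(4), which the
# stabilization override above mirrors with rew_state_weight: [1, 1, 1, 1].
```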
diff --git a/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml b/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml
index b942acb4a..4e800fa31 100644
--- a/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml
+++ b/examples/mpc/config_overrides/cartpole/cartpole_stabilization.yaml
@@ -30,6 +30,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
 
   constraints:
diff --git a/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml b/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml
index 37c07aa09..cd4164f41 100644
--- a/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml
+++ b/examples/mpc/config_overrides/cartpole/cartpole_tracking.yaml
@@ -33,6 +33,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1] # Match MPC weights
+  rew_act_weight: [0.1]
   done_on_out_of_bound: True
 
   constraints:
diff --git a/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml
index ee1853730..3786a1db4 100644
--- a/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml
+++ b/examples/mpc/config_overrides/quadrotor_2D/linear_mpc_quadrotor_2D_tracking.yaml
@@ -5,9 +5,9 @@ algo_config:
     - 0.1
     - 0.1
   q_mpc:
-    - 1.0
+    - 5.0
     - 0.1
-    - 1.0
+    - 5.0
     - 0.1
     - 0.1
     - 0.1
diff --git a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml
index 494c9aefa..8d380fc11 100644
--- a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml
+++ b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_stabilization.yaml
@@ -42,6 +42,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights
+  rew_act_weight: [0.1, 0.1]
   done_on_out_of_bound: True
 
   constraints:
diff --git a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml
index 2405a2238..e3eb0959b 100644
--- a/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml
+++ b/examples/mpc/config_overrides/quadrotor_2D/quadrotor_2D_tracking.yaml
@@ -45,6 +45,8 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 0.1, 0.1] # Match MPC weights
+  rew_act_weight: [0.1, 0.1]
   done_on_out_of_bound: True
 
   constraints:
diff --git a/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml
index 1102acce3..009f561b2 100644
--- a/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml
+++ b/examples/mpc/config_overrides/quadrotor_3D/linear_mpc_quadrotor_3D_tracking.yaml
@@ -7,11 +7,11 @@ algo_config:
     - 0.1
     - 0.1
   q_mpc:
-    - 1.0
+    - 5.0
     - 0.1
-    - 1.0
+    - 5.0
    - 0.1
-    - 1.0
+    - 5.0
     - 0.1
     - 0.1
     - 0.1
diff --git a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml
index 4c652d130..305cda6eb 100644
--- a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml
+++ b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_stabilization.yaml
@@ -68,6 +68,9 @@ task_config:
 
   episode_len_sec: 6
   cost: quadratic
+  # Match MPC weights
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1, 0.1, 0.1, 0.1]
   done_on_out_of_bound: True
   constraints:
     - constraint_form: default_constraint
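For reference, these `rew_state_weight`/`rew_act_weight` entries feed the environment's quadratic cost. A minimal sketch of the stage cost being matched (illustrative function name; the actual reward code lives in the task environments):

```python
import numpy as np

def quadratic_stage_cost(state, action, state_goal, rew_state_weight, rew_act_weight):
    '''x_err^T Q x_err + u^T R u with diagonal Q, R built from the reward weights.'''
    Q = np.diag(np.asarray(rew_state_weight, dtype=float))
    R = np.diag(np.asarray(rew_act_weight, dtype=float))
    x_err = np.asarray(state) - np.asarray(state_goal)
    return float(x_err @ Q @ x_err + np.asarray(action) @ R @ np.asarray(action))
```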
diff --git a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml
index e4263c249..1890b5251 100644
--- a/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml
+++ b/examples/mpc/config_overrides/quadrotor_3D/quadrotor_3D_tracking.yaml
@@ -70,6 +70,9 @@ task_config:
     proj_normal: [0, 1, 1]
   episode_len_sec: 6
   cost: quadratic
+  # Match MPC weights
+  rew_state_weight: [5.0, 0.1, 5.0, 0.1, 5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
+  rew_act_weight: [0.1, 0.1, 0.1, 0.1]
   done_on_out_of_bound: True
   constraints:
     - constraint_form: default_constraint
diff --git a/safe_control_gym/controllers/lqr/ilqr.py b/safe_control_gym/controllers/lqr/ilqr.py
index 3ebc484c7..7407a923f 100644
--- a/safe_control_gym/controllers/lqr/ilqr.py
+++ b/safe_control_gym/controllers/lqr/ilqr.py
@@ -64,7 +64,6 @@ def __init__(
         self.model = self.get_prior(self.env)
         self.Q = get_cost_weight_matrix(self.q_lqr, self.model.nx)
         self.R = get_cost_weight_matrix(self.r_lqr, self.model.nu)
-        self.env.set_cost_function_param(self.Q, self.R)
         self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
                                      self.Q, self.R, self.discrete_dynamics)
 
diff --git a/safe_control_gym/controllers/lqr/lqr.py b/safe_control_gym/controllers/lqr/lqr.py
index 5e9597d51..f03069525 100644
--- a/safe_control_gym/controllers/lqr/lqr.py
+++ b/safe_control_gym/controllers/lqr/lqr.py
@@ -33,7 +33,6 @@ def __init__(
         self.discrete_dynamics = discrete_dynamics
         self.Q = get_cost_weight_matrix(q_lqr, self.model.nx)
         self.R = get_cost_weight_matrix(r_lqr, self.model.nu)
-        self.env.set_cost_function_param(self.Q, self.R)
         self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
                                      self.Q, self.R, self.discrete_dynamics)
 
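With the `set_cost_function_param` calls removed, the LQR/iLQR controllers keep Q and R purely for gain synthesis and no longer mutate the environment's reward. Condensed from `lqr.py` after this change:

```python
# Controller-side weights are still built locally for gain synthesis; the
# env's Q/R now come from rew_state_weight / rew_act_weight (see the YAML
# overrides above) instead of being pushed in by the controller.
self.Q = get_cost_weight_matrix(q_lqr, self.model.nx)
self.R = get_cost_weight_matrix(r_lqr, self.model.nu)
self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
                             self.Q, self.R, self.discrete_dynamics)
```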
diff --git a/safe_control_gym/controllers/mpc/mpc.py b/safe_control_gym/controllers/mpc/mpc.py
index ac2ed4e59..8144638e1 100644
--- a/safe_control_gym/controllers/mpc/mpc.py
+++ b/safe_control_gym/controllers/mpc/mpc.py
@@ -8,8 +8,8 @@
 from safe_control_gym.controllers.base_controller import BaseController
 from safe_control_gym.controllers.lqr.lqr_utils import discretize_linear_system
 from safe_control_gym.controllers.mpc.mpc_utils import (compute_discrete_lqr_gain_from_cont_linear_system,
-                                                        compute_state_rmse, get_cost_weight_matrix,
-                                                        reset_constraints, rk_discrete)
+                                                        get_cost_weight_matrix, reset_constraints,
+                                                        rk_discrete)
 from safe_control_gym.envs.benchmark_env import Task
 from safe_control_gym.envs.constraints import GENERAL_CONSTRAINTS, create_constraint_list
 from safe_control_gym.utils.utils import timing
@@ -431,106 +431,6 @@ def setup_results_dict(self):
             't_wall': []
         }
 
-    def run(self,
-            env=None,
-            render=False,
-            logging=False,
-            max_steps=None,
-            terminate_run_on_done=None
-            ):
-        '''Runs evaluation with current policy.
-
-        Args:
-            render (bool): if to do real-time rendering.
-            logging (bool): if to log on terminal.
-
-        Returns:
-            dict: evaluation statisitcs, rendered frames.
-        '''
-        if env is None:
-            env = self.env
-        if terminate_run_on_done is None:
-            terminate_run_on_done = self.terminate_run_on_done
-
-        self.x_prev = None
-        self.u_prev = None
-        if not env.initial_reset:
-            env.set_cost_function_param(self.Q, self.R)
-        # obs, info = env.reset()
-        obs = env.reset()
-        print('Init State:')
-        print(obs)
-        ep_returns, ep_lengths = [], []
-        frames = []
-        self.setup_results_dict()
-        self.results_dict['obs'].append(obs)
-        self.results_dict['state'].append(env.state)
-        i = 0
-        if env.TASK == Task.STABILIZATION:
-            if max_steps is None:
-                MAX_STEPS = int(env.CTRL_FREQ * env.EPISODE_LEN_SEC)
-            else:
-                MAX_STEPS = max_steps
-        elif env.TASK == Task.TRAJ_TRACKING:
-            if max_steps is None:
-                MAX_STEPS = self.traj.shape[1]
-            else:
-                MAX_STEPS = max_steps
-        else:
-            raise Exception('Undefined Task')
-        self.terminate_loop = False
-        done = False
-        common_metric = 0
-        while not (done and terminate_run_on_done) and i < MAX_STEPS and not (self.terminate_loop):
-            action = self.select_action(obs)
-            if self.terminate_loop:
-                print('Infeasible MPC Problem')
-                break
-            # Repeat input for more efficient control.
-            obs, reward, done, info = env.step(action)
-            self.results_dict['obs'].append(obs)
-            self.results_dict['reward'].append(reward)
-            self.results_dict['done'].append(done)
-            self.results_dict['info'].append(info)
-            self.results_dict['action'].append(action)
-            self.results_dict['state'].append(env.state)
-            self.results_dict['state_mse'].append(info['mse'])
-            self.results_dict['state_error'].append(env.state - env.X_GOAL[i, :])
-            common_metric += info['mse']
-            print(i, '-th step.')
-            print('action:', action)
-            print('obs', obs)
-            print('reward', reward)
-            print('done', done)
-            print(info)
-            print()
-            if render:
-                env.render()
-                frames.append(env.render('rgb_array'))
-            i += 1
-        # Collect evaluation results.
-        ep_lengths = np.asarray(ep_lengths)
-        ep_returns = np.asarray(ep_returns)
-        if logging:
-            msg = '****** Evaluation ******\n'
-            msg += 'eval_ep_length {:.2f} +/- {:.2f} | eval_ep_return {:.3f} +/- {:.3f}\n'.format(
-                ep_lengths.mean(), ep_lengths.std(), ep_returns.mean(),
-                ep_returns.std())
-        if len(frames) != 0:
-            self.results_dict['frames'] = frames
-        self.results_dict['obs'] = np.vstack(self.results_dict['obs'])
-        self.results_dict['state'] = np.vstack(self.results_dict['state'])
-        try:
-            self.results_dict['reward'] = np.vstack(self.results_dict['reward'])
-            self.results_dict['action'] = np.vstack(self.results_dict['action'])
-            self.results_dict['full_traj_common_cost'] = common_metric
-            self.results_dict['total_rmse_state_error'] = compute_state_rmse(self.results_dict['state'])
-            self.results_dict['total_rmse_obs_error'] = compute_state_rmse(self.results_dict['obs'])
-        except ValueError:
-            raise Exception('[ERROR] mpc.run().py: MPC could not find a solution for the first step given the initial conditions. '
-                            'Check to make sure initial conditions are feasible.')
-        return deepcopy(self.results_dict)
-
     def reset_before_run(self, obs, info=None, env=None):
         '''Reinitialize just the controller before a new run.
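The deleted `MPC.run()` bundled an evaluation rollout, terminal logging, and the now-removed `env.set_cost_function_param` call into the controller. Driving the same loop from outside the controller reduces to roughly this sketch (an assumed harness, not the project's actual evaluation utility):

```python
import numpy as np

def rollout(ctrl, env, max_steps=1000):
    '''Roll a controller out against an env via the Gym-style 4-tuple step API.'''
    obs = env.reset()  # note: some env versions return (obs, info) instead
    rewards = []
    for _ in range(max_steps):
        action = ctrl.select_action(obs)
        obs, reward, done, info = env.step(action)
        rewards.append(reward)
        if done:
            break
    return np.asarray(rewards)
```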
diff --git a/safe_control_gym/envs/benchmark_env.py b/safe_control_gym/envs/benchmark_env.py
index c60220974..61df37a61 100644
--- a/safe_control_gym/envs/benchmark_env.py
+++ b/safe_control_gym/envs/benchmark_env.py
@@ -176,10 +176,6 @@ def __init__(self,
             self.state_dim = self.state_space.shape[0]
         else:
             self.state_dim = self.obs_dim
-        # Default Q and R matrices for quadratic cost.
-        if self.COST == Cost.QUADRATIC:
-            self.Q = np.eye(self.observation_space.shape[0])
-            self.R = np.eye(self.action_space.shape[0])
         # Set constraint info.
         self.CONSTRAINTS = constraints
         self.DONE_ON_VIOLATION = done_on_violation
@@ -221,25 +217,6 @@ def seed(self,
         disturbs.seed(self)
         return [seed]
 
-    def set_cost_function_param(self,
-                                Q,
-                                R
-                                ):
-        '''Set the cost function parameters.
-
-        Args:
-            Q (ndarray): State weight matrix (nx by nx).
-            R (ndarray): Input weight matrix (nu by nu).
-        '''
-
-        if not self.initial_reset:
-            self.Q = Q
-            self.R = R
-        else:
-            raise RuntimeError(
-                '[ERROR] env.set_cost_function_param() cannot be called after the first reset of the environment.'
-            )
-
     def set_adversary_control(self, action):
         '''Sets disturbance by an adversary controller, called before (each) step().
diff --git a/safe_control_gym/envs/gym_control/cartpole.py b/safe_control_gym/envs/gym_control/cartpole.py
index 6a1c91317..7ed6ea442 100644
--- a/safe_control_gym/envs/gym_control/cartpole.py
+++ b/safe_control_gym/envs/gym_control/cartpole.py
@@ -150,7 +150,9 @@ def __init__(self,
         self.obs_goal_horizon = obs_goal_horizon
         self.obs_wrap_angle = obs_wrap_angle
         self.rew_state_weight = np.array(rew_state_weight, ndmin=1, dtype=float)
+        self.Q = np.diag(self.rew_state_weight)
         self.rew_act_weight = np.array(rew_act_weight, ndmin=1, dtype=float)
+        self.R = np.diag(self.rew_act_weight)
         self.rew_exponential = rew_exponential
         self.done_on_out_of_bound = done_on_out_of_bound
         # BenchmarkEnv constructor, called after defining the custom args,
diff --git a/safe_control_gym/envs/gym_pybullet_drones/quadrotor.py b/safe_control_gym/envs/gym_pybullet_drones/quadrotor.py
index 2e2cb1887..2619bdfe5 100644
--- a/safe_control_gym/envs/gym_pybullet_drones/quadrotor.py
+++ b/safe_control_gym/envs/gym_pybullet_drones/quadrotor.py
@@ -178,7 +178,9 @@ def __init__(self,
         self.norm_act_scale = norm_act_scale
         self.obs_goal_horizon = obs_goal_horizon
         self.rew_state_weight = np.array(rew_state_weight, ndmin=1, dtype=float)
+        self.Q = np.diag(self.rew_state_weight)
         self.rew_act_weight = np.array(rew_act_weight, ndmin=1, dtype=float)
+        self.R = np.diag(self.rew_act_weight)
         self.rew_exponential = rew_exponential
         self.done_on_out_of_bound = done_on_out_of_bound
         if info_mse_metric_state_weight is None:
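Net effect: `env.Q` and `env.R` are now derived from the reward weights at construction time, so they always agree with the reward the env reports. A usage sketch (the `make` factory lives in `safe_control_gym.utils.registration`; the exact kwargs here are assumptions based on the overrides above):

```python
import numpy as np
from safe_control_gym.utils.registration import make

env = make('cartpole',
           cost='quadratic',
           rew_state_weight=[5.0, 0.1, 5.0, 0.1],
           rew_act_weight=[0.1])
assert np.allclose(env.Q, np.diag([5.0, 0.1, 5.0, 0.1]))
assert np.allclose(env.R, np.diag([0.1]))
```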