Skip to content

Commit

Permalink
Re-adding the MPC run function as it is necessary for GP-MPC
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Dec 2, 2024
1 parent 44c38bb commit 3313d2a
Showing 1 changed file with 99 additions and 2 deletions.
101 changes: 99 additions & 2 deletions safe_control_gym/controllers/mpc/mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.lqr.lqr_utils import discretize_linear_system
from safe_control_gym.controllers.mpc.mpc_utils import (compute_discrete_lqr_gain_from_cont_linear_system,
get_cost_weight_matrix, reset_constraints,
rk_discrete)
compute_state_rmse, get_cost_weight_matrix,
reset_constraints, rk_discrete)
from safe_control_gym.envs.benchmark_env import Task
from safe_control_gym.envs.constraints import GENERAL_CONSTRAINTS, create_constraint_list
from safe_control_gym.utils.utils import timing
Expand Down Expand Up @@ -431,6 +431,103 @@ def setup_results_dict(self):
't_wall': []
}

def run(self,
env=None,
render=False,
logging=False,
max_steps=None,
terminate_run_on_done=None
):
'''Runs evaluation with current policy.
Args:
render (bool): if to do real-time rendering.
logging (bool): if to log on terminal.
Returns:
dict: evaluation statisitcs, rendered frames.
'''
if env is None:
env = self.env
if terminate_run_on_done is None:
terminate_run_on_done = self.terminate_run_on_done

self.x_prev = None
self.u_prev = None
obs, info = env.reset()
print('Init State:')
print(obs)
ep_returns, ep_lengths = [], []
frames = []
self.setup_results_dict()
self.results_dict['obs'].append(obs)
self.results_dict['state'].append(env.state)
i = 0
if env.TASK == Task.STABILIZATION:
if max_steps is None:
MAX_STEPS = int(env.CTRL_FREQ * env.EPISODE_LEN_SEC)
else:
MAX_STEPS = max_steps
elif env.TASK == Task.TRAJ_TRACKING:
if max_steps is None:
MAX_STEPS = self.traj.shape[1]
else:
MAX_STEPS = max_steps
else:
raise Exception('Undefined Task')
self.terminate_loop = False
done = False
common_metric = 0
while not (done and terminate_run_on_done) and i < MAX_STEPS and not (self.terminate_loop):
action = self.select_action(obs)
if self.terminate_loop:
print('Infeasible MPC Problem')
break
# Repeat input for more efficient control.
obs, reward, done, info = env.step(action)
self.results_dict['obs'].append(obs)
self.results_dict['reward'].append(reward)
self.results_dict['done'].append(done)
self.results_dict['info'].append(info)
self.results_dict['action'].append(action)
self.results_dict['state'].append(env.state)
self.results_dict['state_mse'].append(info['mse'])
self.results_dict['state_error'].append(env.state - env.X_GOAL[i, :])
common_metric += info['mse']
print(i, '-th step.')
print('action:', action)
print('obs', obs)
print('reward', reward)
print('done', done)
print(info)
print()
if render:
env.render()
frames.append(env.render('rgb_array'))
i += 1
# Collect evaluation results.
ep_lengths = np.asarray(ep_lengths)
ep_returns = np.asarray(ep_returns)
if logging:
msg = '****** Evaluation ******\n'
msg += 'eval_ep_length {:.2f} +/- {:.2f} | eval_ep_return {:.3f} +/- {:.3f}\n'.format(
ep_lengths.mean(), ep_lengths.std(), ep_returns.mean(),
ep_returns.std())
if len(frames) != 0:
self.results_dict['frames'] = frames
self.results_dict['obs'] = np.vstack(self.results_dict['obs'])
self.results_dict['state'] = np.vstack(self.results_dict['state'])
try:
self.results_dict['reward'] = np.vstack(self.results_dict['reward'])
self.results_dict['action'] = np.vstack(self.results_dict['action'])
self.results_dict['full_traj_common_cost'] = common_metric
self.results_dict['total_rmse_state_error'] = compute_state_rmse(self.results_dict['state'])
self.results_dict['total_rmse_obs_error'] = compute_state_rmse(self.results_dict['obs'])
except ValueError:
raise Exception('[ERROR] mpc.run().py: MPC could not find a solution for the first step given the initial conditions. '
'Check to make sure initial conditions are feasible.')
return deepcopy(self.results_dict)

def reset_before_run(self, obs, info=None, env=None):
'''Reinitialize just the controller before a new run.
Expand Down

0 comments on commit 3313d2a

Please sign in to comment.