diff --git a/experiments/mpsc/mpsc_experiment.py b/experiments/mpsc/mpsc_experiment.py index 2cfb0d937..4931803ed 100644 --- a/experiments/mpsc/mpsc_experiment.py +++ b/experiments/mpsc/mpsc_experiment.py @@ -31,8 +31,6 @@ def run(plot=False, model='ppo'): config.task_config['done_on_violation'] = False config.task_config['randomized_init'] = False - system = 'quadrotor_2D_attitude' - # Create an environment env_func = partial(make, config.task, @@ -76,8 +74,6 @@ def run(plot=False, model='ppo'): safety_filter.cost_function.uncertified_controller = ctrl safety_filter.cost_function.output_dir = '.' - safety_filter.load(path=f'./models/mpsc_parameters/{config.safety_filter}_{system}.pkl') - # Run with safety filter experiment = BaseExperiment(env, ctrl, safety_filter=safety_filter) cert_results, cert_metrics = experiment.run_evaluation(n_episodes=1) diff --git a/experiments/mpsc/train_rl.py b/experiments/mpsc/train_rl.py index a3f79014a..ec3db251f 100644 --- a/experiments/mpsc/train_rl.py +++ b/experiments/mpsc/train_rl.py @@ -8,6 +8,7 @@ import munch import yaml +from safe_control_gym.experiments.base_experiment import BaseExperiment from safe_control_gym.safety_filters.mpsc.mpsc_utils import Cost_Function from safe_control_gym.utils.configuration import ConfigFactory from safe_control_gym.utils.plotting import plot_from_logs @@ -27,8 +28,6 @@ def train(): shutil.rmtree(config.output_dir, ignore_errors=True) - system = 'quadrotor_2D_attitude' - set_seed_from_config(config) set_device_from_config(config) @@ -38,6 +37,7 @@ def train(): output_dir=config.output_dir, **config.task_config ) + env = env_func() # Create the controller/control_agent. ctrl = make(config.algo, @@ -58,13 +58,12 @@ def train(): safety_filter.cost_function.uncertified_controller = ctrl safety_filter.cost_function.output_dir = '.' - safety_filter.load(path=f'./models/mpsc_parameters/{config.safety_filter}_{system}.pkl') - ctrl.safety_filter = safety_filter # Training. start_time = time.time() - ctrl.learn() + experiment = BaseExperiment(env, ctrl, safety_filter=safety_filter) + experiment.launch_training() config['logging'] = {'total_learning_time': time.time() - start_time} ctrl.close() print('Training done.')