Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Dec 11, 2024
1 parent 8f94570 commit 7c5b194
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
4 changes: 0 additions & 4 deletions experiments/mpsc/mpsc_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ def run(plot=False, model='ppo'):
config.task_config['done_on_violation'] = False
config.task_config['randomized_init'] = False

system = 'quadrotor_2D_attitude'

# Create an environment
env_func = partial(make,
config.task,
Expand Down Expand Up @@ -76,8 +74,6 @@ def run(plot=False, model='ppo'):
safety_filter.cost_function.uncertified_controller = ctrl
safety_filter.cost_function.output_dir = '.'

safety_filter.load(path=f'./models/mpsc_parameters/{config.safety_filter}_{system}.pkl')

# Run with safety filter
experiment = BaseExperiment(env, ctrl, safety_filter=safety_filter)
cert_results, cert_metrics = experiment.run_evaluation(n_episodes=1)
Expand Down
9 changes: 4 additions & 5 deletions experiments/mpsc/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import munch
import yaml

from safe_control_gym.experiments.base_experiment import BaseExperiment
from safe_control_gym.safety_filters.mpsc.mpsc_utils import Cost_Function
from safe_control_gym.utils.configuration import ConfigFactory
from safe_control_gym.utils.plotting import plot_from_logs
Expand All @@ -27,8 +28,6 @@ def train():

shutil.rmtree(config.output_dir, ignore_errors=True)

system = 'quadrotor_2D_attitude'

set_seed_from_config(config)
set_device_from_config(config)

Expand All @@ -38,6 +37,7 @@ def train():
output_dir=config.output_dir,
**config.task_config
)
env = env_func()

# Create the controller/control_agent.
ctrl = make(config.algo,
Expand All @@ -58,13 +58,12 @@ def train():
safety_filter.cost_function.uncertified_controller = ctrl
safety_filter.cost_function.output_dir = '.'

safety_filter.load(path=f'./models/mpsc_parameters/{config.safety_filter}_{system}.pkl')

ctrl.safety_filter = safety_filter

# Training.
start_time = time.time()
ctrl.learn()
experiment = BaseExperiment(env, ctrl, safety_filter=safety_filter)
experiment.launch_training()
config['logging'] = {'total_learning_time': time.time() - start_time}
ctrl.close()
print('Training done.')
Expand Down

0 comments on commit 7c5b194

Please sign in to comment.