diff --git a/experiments/mpsc/mpsc_experiment.py b/experiments/mpsc/mpsc_experiment.py index c5b2d772a..a25cda931 100644 --- a/experiments/mpsc/mpsc_experiment.py +++ b/experiments/mpsc/mpsc_experiment.py @@ -6,7 +6,7 @@ import numpy as np -from safe_control_gym.experiments.base_experiment import BaseExperiment +from safe_control_gym.experiments.base_experiment import BaseExperiment, MetricExtractor from safe_control_gym.utils.registration import make from safe_control_gym.utils.configuration import ConfigFactory from safe_control_gym.envs.benchmark_env import Task, Cost, Environment @@ -426,8 +426,9 @@ def run_multiple(plot=True): all_uncert_results[key].append(uncert_results[key][0]) all_cert_results[key].append(cert_results[key][0]) - uncert_metrics = BaseExperiment.compute_metrics(all_uncert_results) - cert_metrics = BaseExperiment.compute_metrics(all_cert_results) + met = MetricExtractor() + uncert_metrics = met.compute_metrics(data=all_uncert_results) + cert_metrics = met.compute_metrics(data=all_cert_results) all_results = {'uncert_results': all_uncert_results, 'uncert_metrics': uncert_metrics, diff --git a/safe_control_gym/safety_filters/cbf/cbf.py b/safe_control_gym/safety_filters/cbf/cbf.py index 172853bd7..eddd8329e 100644 --- a/safe_control_gym/safety_filters/cbf/cbf.py +++ b/safe_control_gym/safety_filters/cbf/cbf.py @@ -232,9 +232,11 @@ def certify_action(self, success (bool): Whether the safety filtering was successful or not. ''' + uncertified_action = np.clip(uncertified_action, self.env.physical_action_bounds[0], self.env.physical_action_bounds[1]) self.results_dict['uncertified_action'].append(uncertified_action) certified_action, success = self.solve_optimization(current_state, uncertified_action) self.results_dict['feasible'].append(success) + certified_action = np.squeeze(np.array(certified_action)) self.results_dict['certified_action'].append(certified_action) self.results_dict['correction'].append(np.linalg.norm(certified_action - uncertified_action)) diff --git a/safe_control_gym/safety_filters/mpsc/mpsc.py b/safe_control_gym/safety_filters/mpsc/mpsc.py index 5f21f6690..f29809bab 100644 --- a/safe_control_gym/safety_filters/mpsc/mpsc.py +++ b/safe_control_gym/safety_filters/mpsc/mpsc.py @@ -246,6 +246,7 @@ def certify_action(self, success = False certified_action = clipped_action + certified_action = np.squeeze(np.array(certified_action)) self.results_dict['kinf'].append(self.kinf) self.results_dict['certified_action'].append(certified_action) self.results_dict['correction'].append(np.linalg.norm(certified_action - uncertified_action))