Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed May 15, 2024
1 parent 53143d2 commit 99142ae
Show file tree
Hide file tree
Showing 37 changed files with 7 additions and 56 deletions.
2 changes: 1 addition & 1 deletion experiments/mpsc/mpsc_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def determine_feasible_starting_points(num_points=100):
# Define arguments.
fac = ConfigFactory()
config = fac.merge()
config.sf_config.cost_function='one_step_cost'
config.sf_config.cost_function = 'one_step_cost'

task = 'stab' if config.task_config.task == Task.STABILIZATION else 'track'
if config.task == Environment.QUADROTOR:
Expand Down
55 changes: 3 additions & 52 deletions experiments/mpsc/plotting_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

plot = False
save_figs = True
ordered_algos = ['lqr', 'ppo', 'sac']
# ordered_algos = ['lqr', 'pid', 'ppo', 'sac']

U_EQs = {
'cartpole': 0,
Expand All @@ -27,51 +25,6 @@
met.verbose = False


def load_one_experiment(system, task, algo, mpsc_cost_horizon):
'''Loads the results of every MPSC cost function for a specific experiment.
Args:
system (str): The system to be controlled.
task (str): The task to be completed (either 'stab' or 'track').
algo (str): The controller being used.
mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
Returns:
all_results (dict): A dictionary containing all the results.
'''

all_results = {}

for cost in ordered_costs:
with open(f'./results_mpsc/{system}/{task}/m{mpsc_cost_horizon}/results_{system}_{task}_{algo}_{cost}_cost_m{mpsc_cost_horizon}.pkl', 'rb') as f:
all_results[cost] = pickle.load(f)

return all_results


def load_all_algos(system, task, mpsc_cost_horizon):
'''Loads the results of every MPSC cost function for a specific experiment with every algo.
Args:
system (str): The system to be controlled.
task (str): The task to be completed (either 'stab' or 'track').
mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
Returns:
all_results (dict): A dictionary containing all the results.
'''

all_results = {}

for algo in ['lqr', 'pid', 'ppo', 'sac']:
if system == 'cartpole' and algo == 'pid':
continue

all_results[algo] = load_one_experiment(system, task, algo, mpsc_cost_horizon)

return all_results


def load_all_models(system, task, algo):
'''Loads the results of every MPSC cost function for a specific experiment with every algo.
Expand Down Expand Up @@ -566,7 +519,7 @@ def plot_model_comparisons(system, task, algo, data_extractor):
ax.set_ylabel(ylabel, weight='bold', fontsize=45, labelpad=10)

x = np.arange(1, len(labels) + 1)
ax.set_xticks(x, labels, weight='bold', fontsize=15, rotation=45, ha='right')
ax.set_xticks(x, labels, weight='bold', fontsize=15, rotation=30, ha='right')

medianprops = dict(linestyle='--', linewidth=2.5, color='black')
bplot = ax.boxplot(data, patch_artist=True, labels=labels, medianprops=medianprops, widths=[0.75] * len(labels))
Expand Down Expand Up @@ -629,9 +582,6 @@ def plot_all_logs(system, task, algo):
for seed in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/{model}/'):
all_results[model].append(load_from_logs(f'./models/rl_models/{system}/{task}/{algo}/{model}/{seed}/logs/'))

# all_results['safe_ppo'] = load_from_logs(f'./models/rl_models/{system}/{task}/safe_explorer_ppo/none/logs/')
# all_results['cpo'] = load_from_logs(f'./models/rl_models/{system}/{task}/cpo/none/logs/')

for key in all_results['none'][0].keys():
plot_log(system, task, algo, key, all_results)

Expand All @@ -655,12 +605,13 @@ def plot_log(system, task, algo, key, all_results):
colors = {'mpsf_sr_pen_1': 'lightgreen', 'mpsf_sr_pen_10': 'limegreen', 'mpsf_sr_pen_100': 'forestgreen', 'mpsf_sr_pen_1000': 'darkgreen', 'none': 'cornflowerblue', 'none_cpen': 'plum'}

for model in labels:
x = all_results[model][0][key][1]
x = all_results[model][0][key][1] / 1000
all_data = np.array([values[key][3] for values in all_results[model]])
ax.plot(x, np.mean(all_data, axis=0), label=model, color=colors[model])
ax.fill_between(x, np.min(all_data, axis=0), np.max(all_data, axis=0), alpha=0.3, edgecolor=colors[model], facecolor=colors[model])

ax.set_ylabel(key, weight='bold', fontsize=45, labelpad=10)
ax.set_xlabel('Training Episodes')
ax.legend()

fig.tight_layout()
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion experiments/mpsc/train_rl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
'''Template training/plotting/testing script.'''

from functools import partial
import os
import shutil
import time
from functools import partial

import munch
import yaml
Expand Down
2 changes: 1 addition & 1 deletion safe_control_gym/controllers/cpo/cpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def get_cost(self, info):
state_constraints = np.maximum(info['constraint_values'][:nx], info['constraint_values'][nx:nx * 2])
constraint_width = info['constraint_values'][:nx] + info['constraint_values'][nx:nx * 2]
state_cost = np.divide(state_constraints, -constraint_width / 2) + 0.0001
return np.sum([max(s,0) for s in state_cost])
return np.sum([max(s, 0) for s in state_cost])

def normalizeAction(self, a):
return normalize(a, self.action_bound_max, self.action_bound_min)
Expand Down
2 changes: 1 addition & 1 deletion safe_control_gym/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import logging
import os
from collections import defaultdict
import time
from collections import defaultdict

import imageio
import numpy as np
Expand Down

0 comments on commit 99142ae

Please sign in to comment.