Skip to content

Commit

Permalink
Fixing decay factor decrease
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Nov 18, 2024
1 parent 162075e commit 21f621a
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 42 deletions.
8 changes: 4 additions & 4 deletions experiments/mpsc/mpsc_experiment.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/bin/bash

SYS='cartpole'
# SYS='cartpole'
# SYS='quadrotor_2D'
# SYS='quadrotor_3D'
SYS='quadrotor_3D'

TASK='stab'
# TASK='track'
# TASK='stab'
TASK='track'

# ALGO='lqr'
# ALGO='pid'
Expand Down
37 changes: 11 additions & 26 deletions experiments/mpsc/plotting_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def load_all_models(system, task, algo):

for model in ordered_models:
all_results[model] = []
for seed in os.listdir(f'./results_mpsc/{system}/{task}/{algo}/results_{system}_{task}_{algo}_{model}/'):
with open(f'./results_mpsc/{system}/{task}/{algo}/results_{system}_{task}_{algo}_{model}/{seed}', 'rb') as f:
all_results[model].append(pickle.load(f))
with open(f'./results_mpsc/{model}.pkl', 'rb') as f:
all_results[model].append(pickle.load(f))
consolidate_multiple_seeds(all_results, model)

return all_results
Expand Down Expand Up @@ -525,7 +524,7 @@ def plot_model_comparisons(system, task, algo, data_extractor):
medianprops = dict(linestyle='--', linewidth=2.5, color='black')
bplot = ax.boxplot(data, patch_artist=True, labels=labels, medianprops=medianprops, widths=[0.75] * len(labels), showfliers=False)

for patch, color in zip(bplot['boxes'], colors.values()):
for patch, color in zip(bplot['boxes'], colors):
patch.set_facecolor(color)

fig.tight_layout()
Expand All @@ -552,8 +551,7 @@ def plot_step_time(system, task, algo):
all_results = {}
for model in ordered_models:
all_results[model] = []
for seed in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/{model}/'):
all_results[model].append(load_from_logs(f'./models/rl_models/{system}/{task}/{algo}/{model}/{seed}/logs/'))
all_results[model].append(load_from_logs(f'./models/rl_models/{model}/logs/'))

fig = plt.figure(figsize=(16.0, 10.0))
ax = fig.add_subplot(111)
Expand All @@ -575,7 +573,7 @@ def plot_step_time(system, task, algo):
medianprops = dict(linestyle='--', linewidth=2.5, color='black')
bplot = ax.boxplot(data, patch_artist=True, labels=labels, medianprops=medianprops, widths=[0.75] * len(labels), showfliers=False)

for patch, color in zip(bplot['boxes'], colors.values()):
for patch, color in zip(bplot['boxes'], colors):
patch.set_facecolor(color)

fig.tight_layout()
Expand Down Expand Up @@ -627,8 +625,7 @@ def plot_all_logs(system, task, algo):

for model in ordered_models:
all_results[model] = []
for seed in os.listdir(f'./models/rl_models/{system}/{task}/{algo}/{model}/'):
all_results[model].append(load_from_logs(f'./models/rl_models/{system}/{task}/{algo}/{model}/{seed}/logs/'))
all_results[model].append(load_from_logs(f'./models/rl_models/{model}/logs/'))

for key in all_results[ordered_models[0]][0].keys():
if key == 'stat_eval/ep_return':
Expand All @@ -647,13 +644,11 @@ def plot_log(key, all_results):
fig = plt.figure(figsize=(16.0, 10.0))
ax = fig.add_subplot(111)

labels = ordered_models

for model, label in zip(ordered_models, labels):
for index, model in enumerate(ordered_models):
x = all_results[model][0][key][1] / 1000
all_data = np.array([values[key][3] for values in all_results[model]])
ax.plot(x, np.mean(all_data, axis=0), label=label, color=colors[model])
ax.fill_between(x, np.min(all_data, axis=0), np.max(all_data, axis=0), alpha=0.3, edgecolor=colors[model], facecolor=colors[model])
ax.plot(x, np.mean(all_data, axis=0), label=model, color=colors[index])
# ax.fill_between(x, np.min(all_data, axis=0), np.max(all_data, axis=0), alpha=0.3, edgecolor=colors[index], facecolor=colors[index])

ax.set_ylabel(key, weight='bold', fontsize=45, labelpad=10)
ax.set_xlabel('Training Episodes')
Expand All @@ -671,18 +666,8 @@ def plot_log(key, all_results):


if __name__ == '__main__':
ordered_models = ['none', 'none_cpen_0.01', 'none_cpen_0.1', 'none_cpen_1', 'mpsf_sr_pen_0.1', 'mpsf_sr_pen_1', 'mpsf_sr_pen_10', 'mpsf_sr_pen_100']

colors = {
'none': 'cornflowerblue',
'none_cpen_0.01': 'plum',
'none_cpen_0.1': 'mediumorchid',
'none_cpen_1': 'darkorchid',
'mpsf_sr_pen_0.1': 'lightgreen',
'mpsf_sr_pen_1': 'limegreen',
'mpsf_sr_pen_10': 'forestgreen',
'mpsf_sr_pen_100': 'darkgreen',
}
ordered_models = [model for model in os.listdir('./models/rl_models/') if 'curriculum' in model]
colors = plt.cm.viridis(np.linspace(0, 1, len(ordered_models)))

def extract_rate_of_change_of_inputs(results_data, certified=True):
return extract_rate_of_change(results_data, certified, order=1, mode='input')
Expand Down
2 changes: 1 addition & 1 deletion experiments/mpsc/train_all_models.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
sbatch train_model.sbatch True 1 1
sbatch train_model.sbatch False 1 1
for MPSC_COST_HORIZON in 2 5 10 20; do
for DECAY_FACTOR in 0.25 0.5 0.75 1; do
sbatch train_model.sbatch True $MPSC_COST_HORIZON $DECAY_FACTOR
Expand Down
9 changes: 5 additions & 4 deletions experiments/mpsc/train_model.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ python3 train_rl.py \
sf_config.decay_factor=${DECAY_FACTOR} \
sf_config.max_decay_factor=${DECAY_FACTOR} \
sf_config.soften_constraints=True \
algo_config.filter_train_actions=${FILTER} \
algo_config.use_safe_reset=${FILTER} \
algo_config.penalize_sf_diff=${FILTER} \
algo_config.sf_penalty=SF_PENALTY
algo_config.filter_train_actions=True \
algo_config.use_safe_reset=True \
algo_config.penalize_sf_diff=True \
algo_config.sf_penalty=$SF_PENALTY \
algo_config.decay_factor_curriculum=$1

./mpsc_experiment.sh $TAG $SYS $TASK $ALGO
3 changes: 2 additions & 1 deletion safe_control_gym/controllers/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def learn(self,
):
'''Performs learning (pre-training, training, fine-tuning, etc).'''
while self.total_steps < self.max_env_steps:
self.safety_filter.decay_factor = self.safety_filter.max_decay_factor * (self.total_steps / self.max_env_steps)
if self.decay_factor_curriculum:
self.safety_filter.set_decay_factor(self.safety_filter.max_decay_factor * (self.total_steps / self.max_env_steps))
results = self.train_step()
# Checkpoint.
if self.total_steps >= self.max_env_steps or (self.save_interval and self.total_steps % self.save_interval == 0):
Expand Down
16 changes: 10 additions & 6 deletions safe_control_gym/safety_filters/mpsc/nl_mpsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,12 +1040,6 @@ def setup_acados_optimizer(self):
solver_json = 'acados_ocp_mpsf.json'
ocp_solver = AcadosOcpSolver(ocp, json_file=solver_json, generate=True, build=True)

for stage in range(self.mpsc_cost_horizon):
ocp_solver.cost_set(stage, 'W', (self.cost_function.decay_factor**stage) * ocp.cost.W)

for stage in range(self.mpsc_cost_horizon, self.horizon):
ocp_solver.cost_set(stage, 'W', 0 * ocp.cost.W)

s_var = np.zeros((self.horizon + 1))
g = np.zeros((self.horizon, self.p))

Expand All @@ -1057,4 +1051,14 @@ def setup_acados_optimizer(self):
g[i, :] += (self.L_x @ self.X_mid) + (self.L_u @ self.U_mid)
ocp_solver.constraints_set(i, 'ug', g[i, :])

self.ocp = ocp
self.ocp_solver = ocp_solver

self.set_decay_factor(self.cost_function.decay_factor)

def set_decay_factor(self, new_decay_factor):
    '''Re-weights the per-stage cost of the acados solver with a new decay factor.

    Stages inside the MPSC cost horizon receive the base cost weight `W`
    scaled by `new_decay_factor**stage`. The remaining stages, up to the
    prediction horizon, receive that weight multiplied by zero.

    Args:
        new_decay_factor: Base of the per-stage geometric scaling.
    '''
    # Stages that contribute to the cost: scale geometrically with the stage index.
    discounted_stages = range(self.mpsc_cost_horizon)
    for k in discounted_stages:
        scaled_weight = (new_decay_factor**k) * self.ocp.cost.W
        self.ocp_solver.cost_set(k, 'W', scaled_weight)

    # Stages past the cost horizon: zero out their weight.
    ignored_stages = range(self.mpsc_cost_horizon, self.horizon)
    for k in ignored_stages:
        zeroed_weight = 0 * self.ocp.cost.W
        self.ocp_solver.cost_set(k, 'W', zeroed_weight)

0 comments on commit 21f621a

Please sign in to comment.