Skip to content

Commit

Permalink
Trained each approach 5 times with 5 seeds
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Jun 20, 2024
1 parent 25bf714 commit cd27a45
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 62 deletions.
89 changes: 49 additions & 40 deletions experiments/crazyflie/plotting_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@

import matplotlib.pyplot as plt
import numpy as np
# import tikzplotlib
# from matplotlib.legend import Legend
# from matplotlib.lines import Line2D

# Line2D._us_dashSeq = property(lambda self: self._dash_pattern[1])
# Line2D._us_dashOffset = property(lambda self: self._dash_pattern[0])
# Legend._ncol = property(lambda self: self._ncols)
import tikzplotlib
from matplotlib.legend import Legend
from matplotlib.lines import Line2D

from safe_control_gym.safety_filters.mpsc.mpsc_utils import get_discrete_derivative
from safe_control_gym.utils.plotting import load_from_logs

Line2D._us_dashSeq = property(lambda self: self._dash_pattern[1])
Line2D._us_dashOffset = property(lambda self: self._dash_pattern[0])
Legend._ncol = property(lambda self: self._ncols)


plot = False
save_figs = True
Expand Down Expand Up @@ -72,13 +72,13 @@ def load_all_models(algo):

model_data['uncertified_action'].append(model_data['certified_action'][-1] - model_data['corrections'][-1])

error = model_data['states'][-1][:, [0,2]] - traj_goal[:, [0,2]]
error = model_data['states'][-1][:, [0, 2]] - traj_goal[:, [0, 2]]
dist = np.sum(2 * error * error, axis=1)
reward = np.sum(np.exp(-dist))
model_data['rewards'].append(reward)

# TODO fix this
constr_viols = np.sum(np.sum(np.abs(model_data['states'][-1][:, [0,1,2,3,6,7]]) > np.array([[0.95, 2, 0.95, 2, 0.25, 0.25]]), axis=1) > 0)
constr_viols = np.sum(np.sum(np.abs(model_data['states'][-1][:, [0, 1, 2, 3, 6, 7]]) > np.array([[0.95, 2, 0.95, 2, 0.25, 0.25]]), axis=1) > 0)
model_data['constraint_violations'].append(constr_viols)

model_data['length'].append(len(model_data['states'][-1]))
Expand Down Expand Up @@ -224,6 +224,7 @@ def extract_constraint_violations(results_data):
num_violations = np.asarray(results_data['constraint_violations'])
return num_violations


def extract_rate_of_change(results_data):
'''Extracts the rate of change of a signal from an experiment's data.
Expand All @@ -244,6 +245,7 @@ def extract_rate_of_change(results_data):

return total_derivatives


def extract_reward(results_data):
'''Extracts the mean reward from an experiment's data.
Expand Down Expand Up @@ -318,17 +320,15 @@ def plot_all_logs(algo):
'''Plots comparative plots of all the logs.
Args:
system (str): The system to be controlled.
task (str): The task to be completed (either 'stab' or 'track').
mpsc_cost_horizon (str): The cost horizon used by the smooth MPSC cost functions.
'''
all_results = {}

for model in ordered_models:
all_results[model] = [load_from_logs(f'./models/rl_models/{algo}/{model}/logs/')]

# all_results['safe_ppo'] = load_from_logs(f'./models/rl_models/safe_explorer_ppo/none/logs/')
# all_results['cpo'] = load_from_logs(f'./models/rl_models/cpo/none/logs/')
all_results[model] = []
for seed in os.listdir(f'./models/rl_models/{algo}/{model}/'):
if 'seed' in seed:
all_results[model].append(load_from_logs(f'./models/rl_models/{algo}/{model}/{seed}/logs/'))

for key in all_results[ordered_models[0]][0].keys():
plot_log(algo, key, all_results)
Expand All @@ -345,22 +345,30 @@ def plot_log(algo, key, all_results):
fig = plt.figure(figsize=(16.0, 10.0))
ax = fig.add_subplot(111)

labels = sorted(all_results.keys())
labels = [label for label in labels if '_es' not in label]
labels = [f'mpsf_0.1{suffix}', f'mpsf_1{suffix}', f'mpsf_10{suffix}', f'none{suffix}', f'none_cpen{suffix}']

colors = plt.colormaps['tab20'].colors
colors = {f'mpsf_0.1{suffix}': 'limegreen', f'mpsf_1{suffix}': 'forestgreen', f'mpsf_10{suffix}': 'darkgreen', f'none{suffix}': 'cornflowerblue', f'none_cpen{suffix}': 'plum'}

for i, model in enumerate(labels):
if key == 'loss/critic_loss' and model == 'safe_ppo':
continue
if key in ['loss/policy_loss', 'loss/critic_loss'] and model == 'cpo':
continue
x = all_results[model][0][key][1]
all_data = np.array([values[key][3] for values in all_results[model]])
ax.plot(x, np.mean(all_data, axis=0), label=model, color=colors[i])
ax.fill_between(x, np.min(all_data, axis=0), np.max(all_data, axis=0), alpha=0.3, edgecolor=colors[i], facecolor=colors[i])
names = {
f'mpsf_0.1{suffix}': r'\textbf{Ours \$\boldsymbol{\alpha=10^{-1}}\$}',
f'mpsf_1{suffix}': r'\textbf{Ours \$\boldsymbol{\alpha=10^0}\$}',
f'mpsf_10{suffix}': r'\textbf{Ours \$\boldsymbol{\alpha=10^1}\$}',
f'none{suffix}': 'Standard',
f'none_cpen{suffix}': 'Constr. Pen.'
}

ax.set_ylabel(key, weight='bold', fontsize=45, labelpad=10)
for model in labels:
x = all_results[model][0][key][1] / 1000
all_data = np.array([values[key][3] for values in all_results[model]])
ax.plot(x, np.mean(all_data, axis=0), label=names[model], color=colors[model])
ax.fill_between(x, np.min(all_data, axis=0), np.max(all_data, axis=0), alpha=0.3, edgecolor=colors[model], facecolor=colors[model])

ylabels = {
'stat_eval/ep_return': 'Episode Return',
'stat/ep_constraint_violation': 'Episode Constraint Violations'
}
ax.set_ylabel(ylabels[key] if key in ylabels else key, weight='bold', fontsize=45, labelpad=10)
ax.set_xlabel('Training Episodes')
ax.legend()

fig.tight_layout()
Expand All @@ -374,26 +382,27 @@ def plot_log(algo, key, all_results):
fig.savefig(f'./results_cf/{algo}/graphs/{suffix[1:]}/{image_suffix}.png', dpi=300)
else:
fig.savefig(f'./results_cf/{algo}/graphs/{image_suffix}.png', dpi=300)
tikzplotlib.save(f'./{image_suffix}.tex', axis_height='2.2in', axis_width='3.5in')
plt.close()


if __name__ == '__main__':
REAL = True
REAL = False
CERTIFIED = True
algo_name = 'ppo'
all_results = load_all_models(algo_name)

create_paper_plot(extract_magnitude_of_corrections)
create_paper_plot(extract_percent_magnitude_of_corrections)
create_paper_plot(extract_max_correction)
create_paper_plot(extract_percent_max_correction)
create_paper_plot(extract_rate_of_change)
create_paper_plot(extract_number_of_corrections)
create_paper_plot(extract_feasible_iterations)
create_paper_plot(extract_reward)
create_paper_plot(extract_rmse)
create_paper_plot(extract_constraint_violations)
create_paper_plot(extract_length)
# create_paper_plot(extract_magnitude_of_corrections)
# create_paper_plot(extract_percent_magnitude_of_corrections)
# create_paper_plot(extract_max_correction)
# create_paper_plot(extract_percent_max_correction)
# create_paper_plot(extract_rate_of_change)
# create_paper_plot(extract_number_of_corrections)
# create_paper_plot(extract_feasible_iterations)
# create_paper_plot(extract_reward)
# create_paper_plot(extract_rmse)
# create_paper_plot(extract_constraint_violations)
# create_paper_plot(extract_length)

if not REAL:
plot_all_logs(algo_name)
22 changes: 17 additions & 5 deletions experiments/crazyflie/test_crazyflie.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ else
SF_PEN_TAG="_$4"
fi

TAG="$1${CONSTR_PEN_TAG}${SF_PEN_TAG}_dm"
if [ -z "$5" ]; then
SEED=1337
else
SEED=$5
fi

TAG="$1${CONSTR_PEN_TAG}${SF_PEN_TAG}"
echo $TAG $SYS $ALGO $TASK

python3 ./train_rl.py \
Expand All @@ -50,14 +56,17 @@ python3 ./train_rl.py \
./config_overrides/crazyflie_${TASK}.yaml \
./config_overrides/${ALGO}_crazyflie.yaml \
./config_overrides/nl_mpsc.yaml \
--output_dir ./models/rl_models/${ALGO}/${TAG} \
--output_dir ./models/rl_models/${ALGO}/${TAG}/seed_${SEED} \
--kv_overrides \
sf_config.cost_function=one_step_cost \
algo_config.filter_train_actions=$FILTER \
algo_config.penalize_sf_diff=$FILTER \
algo_config.use_safe_reset=$FILTER \
algo_config.sf_penalty=$4 \
task_config.use_constraint_penalty=$3
task_config.use_constraint_penalty=$3 \
task_config.seed=${SEED} \
algo_config.seed=${SEED} \
sf_config.seed=${SEED}

python3 ./crazyflie_experiment.py \
--task quadrotor \
Expand All @@ -67,12 +76,15 @@ python3 ./crazyflie_experiment.py \
./config_overrides/crazyflie_${TASK}.yaml \
./config_overrides/${ALGO}_crazyflie.yaml \
./config_overrides/nl_mpsc.yaml \
--output_dir ./models/rl_models/${ALGO}/${TAG} \
--output_dir ./models/rl_models/${ALGO}/${TAG}/seed_${SEED} \
--kv_overrides \
sf_config.cost_function=precomputed_cost \
sf_config.mpsc_cost_horizon=${MPSC_COST_HORIZON} \
algo_config.filter_train_actions=$FILTER \
algo_config.penalize_sf_diff=$FILTER \
algo_config.use_safe_reset=$FILTER \
algo_config.sf_penalty=$4 \
task_config.use_constraint_penalty=$3
task_config.use_constraint_penalty=$3 \
task_config.seed=${SEED} \
algo_config.seed=${SEED} \
sf_config.seed=${SEED}
12 changes: 7 additions & 5 deletions experiments/crazyflie/train_all_models.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#!/bin/bash
for ALGO in ppo; do
./test_crazyflie.sh mpsf $ALGO False 0.1 #mpsf_sr_pen_0.1
./test_crazyflie.sh mpsf $ALGO False 1 #mpsf_sr_pen_1
./test_crazyflie.sh mpsf $ALGO False 10 #mpsf_sr_pen_10
./test_crazyflie.sh none $ALGO False False #none
./test_crazyflie.sh none $ALGO True False #none_cpen
for SEED in 42 62 821 99 4077; do # 1102 1014 14 960406 2031; do
./test_crazyflie.sh mpsf $ALGO False 0.1 $SEED #mpsf_sr_pen_0.1
./test_crazyflie.sh mpsf $ALGO False 1 $SEED #mpsf_sr_pen_1
./test_crazyflie.sh mpsf $ALGO False 10 $SEED #mpsf_sr_pen_10
./test_crazyflie.sh none $ALGO False False $SEED #none
./test_crazyflie.sh none $ALGO True False $SEED #none_cpen
done
done
8 changes: 3 additions & 5 deletions experiments/crazyflie/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

import os
import shutil
import yaml
import sys
sys.path.insert(0, '/home/federico/GitHub/safe-control-gym')

from functools import partial

import numpy as np
sys.path.insert(0, '/home/federico/GitHub/safe-control-gym')

import munch
import yaml

from experiments.crazyflie.crazyflie_utils import gen_traj
from safe_control_gym.utils.configuration import ConfigFactory
Expand Down Expand Up @@ -68,7 +67,6 @@ def train():
env_func,
checkpoint_path=os.path.join(config.output_dir, 'model_latest.pt'),
output_dir=config.output_dir,
seed=1,
**config.algo_config)

ctrl.firmware_wrapper = firmware_wrapper
Expand Down
8 changes: 1 addition & 7 deletions experiments/mpsc/plotting_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
from safe_control_gym.safety_filters.mpsc.mpsc_utils import get_discrete_derivative, high_frequency_content
from safe_control_gym.utils.plotting import load_from_logs

# from scipy.signal import savgol_filter


plot = False
save_figs = True
ordered_algos = ['lqr', 'ppo', 'sac']
Expand Down Expand Up @@ -785,8 +782,6 @@ def plot_model_comparisons(system, task, algo, data_extractor):
labels = labels + ['safe_ppo'] + ['safe_ppo_cert']
labels = labels + ['cpo'] + ['cpo_cert']

labels = [label for label in labels if '_es' not in label]

data = []

for model in labels:
Expand Down Expand Up @@ -882,7 +877,6 @@ def plot_log(system, task, algo, key, all_results):
ax = fig.add_subplot(111)

labels = sorted(all_results.keys())
labels = [label for label in labels if '_es' not in label]

colors = plt.colormaps['tab20'].colors

Expand All @@ -891,7 +885,7 @@ def plot_log(system, task, algo, key, all_results):
continue
if key in ['loss/policy_loss', 'loss/critic_loss'] and model == 'cpo':
continue
y = all_results[model][key][3] # savgol_filter(all_results[model][key][3], window_length=15, polyorder=3)
y = all_results[model][key][3]
ax.plot(all_results[model][key][1], y, label=model, color=colors[i])

ax.set_ylabel(key, weight='bold', fontsize=45, labelpad=10)
Expand Down

0 comments on commit cd27a45

Please sign in to comment.