Adding basic training of PPO on crazyflie with MPSF in the loop
Federico-PizarroBejarano committed Jan 5, 2024
1 parent 5555e97 commit ef6292b
Showing 11 changed files with 324 additions and 126 deletions.
50 changes: 50 additions & 0 deletions experiments/crazyflie/config_overrides/crazyflie_track.yaml
@@ -26,6 +26,56 @@ task_config:
randomized_init: False
randomized_inertial_prop: False

init_state_randomization_info:
init_x:
distrib: 'uniform'
low: -0.75
high: 0.75
init_x_dot:
distrib: 'uniform'
low: -0.5
high: 0.5
init_y:
distrib: 'uniform'
low: -0.75
high: 0.75
init_y_dot:
distrib: 'uniform'
low: -0.5
high: 0.5
init_z:
distrib: 'uniform'
low: 0.5
high: 2
init_z_dot:
distrib: 'uniform'
low: -1
high: 1
init_phi:
distrib: 'uniform'
low: -0.2
high: 0.2
init_theta:
distrib: 'uniform'
low: -0.2
high: 0.2
init_psi:
distrib: 'uniform'
low: -0.2
high: 0.2
init_p:
distrib: 'uniform'
low: -1
high: 1
init_q:
distrib: 'uniform'
low: -1
high: 1
init_r:
distrib: 'uniform'
low: -1
high: 1

task: traj_tracking
task_info:
trajectory_type: figure8
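For reference, the init_state_randomization_info block added above declares one uniform distribution per state dimension. Below is a minimal sketch of how such a spec could be sampled, assuming every entry is uniform over [low, high]; it is an illustration only, not the environment's actual sampling code.

```python
# Illustrative only: draw one initial state from a uniform randomization spec
# shaped like the YAML block above (a subset of the keys is shown).
import numpy as np

randomization_info = {
    'init_x': {'distrib': 'uniform', 'low': -0.75, 'high': 0.75},
    'init_z': {'distrib': 'uniform', 'low': 0.5, 'high': 2.0},
    'init_phi': {'distrib': 'uniform', 'low': -0.2, 'high': 0.2},
}

def sample_init_state(info, rng=None):
    '''Sample one value per state dimension; only the uniform case is handled.'''
    rng = rng or np.random.default_rng()
    return {name: rng.uniform(spec['low'], spec['high'])
            for name, spec in info.items()
            if spec['distrib'] == 'uniform'}

print(sample_init_state(randomization_info))
```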
47 changes: 47 additions & 0 deletions experiments/crazyflie/config_overrides/ppo_crazyflie.yaml
@@ -0,0 +1,47 @@
algo: ppo
algo_config:
# model args
hidden_dim: 128
activation: 'leaky_relu'
norm_obs: False
norm_reward: False
clip_obs: 10.0
clip_reward: 10.0

# loss args
gamma: 0.98
use_gae: False
gae_lambda: 0.8
use_clipped_value: False
clip_param: 0.1
target_kl: 1.587713889686473e-07
entropy_coef: 0.00010753631441212628

# optim args
opt_epochs: 5
mini_batch_size: 128
actor_lr: 0.0007948148615930024
critic_lr: 0.007497368468753617
max_grad_norm: 0.5

# runner args
max_env_steps: 100000
num_workers: 1
rollout_batch_size: 1
rollout_steps: 500
deque_size: 10
eval_batch_size: 10

# misc
log_interval: 6000
save_interval: 0
num_checkpoints: 0
eval_interval: 6000
eval_save_best: True
tensorboard: False

# safety filter
filter_train_actions: True
penalize_sf_diff: True
sf_penalty: 1
use_safe_reset: True
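The last four keys configure how the safety filter participates in training. Here is a hypothetical sketch of how penalize_sf_diff and sf_penalty could shape the reward when the filter corrects an action; the function name and signature are assumptions for illustration, not the repository's implementation.

```python
# Hypothetical reward shaping: penalize the gap between the agent's raw
# (uncertified) action and the safety-filtered (certified) action.
import numpy as np

def shaped_reward(env_reward, uncertified_action, certified_action,
                  penalize_sf_diff=True, sf_penalty=1.0):
    '''Subtract a penalty proportional to how much the filter changed the action.'''
    if not penalize_sf_diff:
        return env_reward
    correction = np.linalg.norm(np.asarray(certified_action) - np.asarray(uncertified_action))
    return env_reward - sf_penalty * correction

# An unchanged action incurs no penalty; a corrected one is penalized.
print(shaped_reward(1.0, [0.10, -0.05], [0.10, -0.05]))  # 1.0
print(shaped_reward(1.0, [0.25, 0.00], [0.15, 0.00]))    # 0.9
```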
40 changes: 30 additions & 10 deletions experiments/crazyflie/crazyflie_experiment.py
@@ -1,6 +1,7 @@
'''Running MPSC using the crazyflie firmware. '''

import sys
import shutil
import time
sys.path.insert(0, '/home/federico/GitHub/safe-control-gym')

@@ -54,8 +55,11 @@ def run(gui=False, plot=True, training=False, certify=True, curr_path='.'):
**config.task_config)

FIRMWARE_FREQ = 500

config.task_config.gui = gui
config.task_config['ctrl_freq'] = FIRMWARE_FREQ
config.algo_config['training'] = False

env_func_500 = partial(make,
config.task,
**config.task_config)
@@ -80,11 +84,20 @@ def run(gui=False, plot=True, training=False, certify=True, curr_path='.'):
ctrl = make(config.algo,
env_func,
**config.algo_config)
ctrl.gain = lqr_gain
ctrl.model.U_EQ = np.array([[0, 0]]).T

ctrl.env.X_GOAL = full_trajectory
ctrl.env.TASK = Task.TRAJ_TRACKING
if config.algo in ['ppo', 'sac', 'safe_explorer_ppo', 'cpo']:
# Load state_dict from trained.
ctrl.load(f'{curr_path}/models/rl_models/{config.algo}/model_best.pt')

# Remove temporary files and directories
shutil.rmtree(f'{curr_path}/temp', ignore_errors=True)
ctrl.X_GOAL = full_trajectory
else:
ctrl.gain = lqr_gain
ctrl.model.U_EQ = np.array([[0, 0]]).T

ctrl.env.X_GOAL = full_trajectory
ctrl.env.TASK = Task.TRAJ_TRACKING

if certify is True:
# Setup MPSC.
@@ -114,8 +127,8 @@ def run(gui=False, plot=True, training=False, certify=True, curr_path='.'):
curr_obs = np.atleast_2d(obs[0:4]).T
curr_obs = curr_obs.reshape((4, 1))
info['current_step'] = i
new_act = ctrl.select_action(curr_obs, info)
new_act = np.clip(new_act, np.array([[-0.25, -0.25]]).T, np.array([[0.25, 0.25]]).T)
new_act = np.squeeze(ctrl.select_action(curr_obs, info))
new_act = np.clip(new_act, np.array([-0.25, -0.25]), np.array([0.25, 0.25]))
actions_uncert.append(new_act)
if certify is True:
certified_action, success = safety_filter.certify_action(curr_obs, new_act, info)
@@ -147,15 +160,15 @@ def run(gui=False, plot=True, training=False, certify=True, curr_path='.'):

states = np.array(states)
actions_uncert = np.array(actions_uncert)
print('Number of Max Inputs: ', np.sum(np.abs(actions_uncert) == 0.25))
print(f'Number of Max Inputs: {np.sum(np.abs(actions_uncert) == 0.25)}/{2*len(actions_uncert)}')
actions_cert = np.array(actions_cert)
corrections = np.squeeze(actions_cert) - np.squeeze(actions_uncert)

# Close the environment
env.close()
print('Elapsed Time: ', time.time() - ep_start)
print('NUM ERRORS POS: ', np.sum(np.abs(states[:, 0]) >= 0.75))
print('NUM ERRORS VEL: ', np.sum(np.abs(states[:, 1]) >= 0.5))
print('NUM VIOLATIONS POS: ', np.sum(np.abs(states[:, 0]) >= 0.75))
print('NUM VIOLATIONS VEL: ', np.sum(np.abs(states[:, 1]) >= 0.5))
print('Rate of change (inputs): ', np.linalg.norm(get_discrete_derivative(np.atleast_2d(actions_cert).T, CTRL_FREQ)))
if certify:
print(f'Feasible steps: {float(successes)}/{CTRL_FREQ*env.EPISODE_LEN_SEC}')
@@ -169,7 +182,14 @@ def run(gui=False, plot=True, training=False, certify=True, curr_path='.'):
plt.legend()
plt.show()

plt.plot(states[:, 0], label='traj')
plt.plot(states[:, 1], label='x vel')
plt.plot(states[:, 3], label='y vel')
plt.plot(states[:, 5], label='z vel')
plt.legend()
plt.show()

plt.plot(states[:, 0], label='x traj')
plt.plot(states[:, 2], label='y traj')
plt.plot(full_trajectory[:, 0], label='ref')
plt.legend()
plt.show()
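The experiment above reports input smoothness as np.linalg.norm(get_discrete_derivative(...)). The sketch below shows the kind of quantity that presumably represents, assuming a first-order finite difference scaled by the control frequency; the repository's actual helper may differ.

```python
# Assumed metric: finite difference of the action sequence times the control
# frequency; its norm summarizes how aggressively the inputs change per step.
import numpy as np

def discrete_derivative(signal, ctrl_freq):
    '''Finite difference of a (T, n) signal, scaled by the control frequency.'''
    signal = np.atleast_2d(signal)
    return (signal[1:] - signal[:-1]) * ctrl_freq

actions = np.array([[0.10, 0.00],   # example values only
                    [0.12, -0.01],
                    [0.15, -0.03]])
print(np.linalg.norm(discrete_derivative(actions, ctrl_freq=25)))
```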
Binary file not shown.
2 changes: 0 additions & 2 deletions experiments/crazyflie/plotting_results.py
@@ -12,8 +12,6 @@
Legend._ncol = property(lambda self: self._ncols)


k

plot = True
save_figs = False

11 changes: 7 additions & 4 deletions experiments/crazyflie/test_crazyflie.sh
@@ -5,23 +5,26 @@ TASK='track'
MPSC='nl_mpsc'
# MPSC='linear_mpsc'

# MPSC_COST='one_step_cost'
MPSC_COST='one_step_cost'
# MPSC_COST='constant_cost'
# MPSC_COST='regularized_cost'
# MPSC_COST='lqr_cost'
MPSC_COST='precomputed_cost'
# MPSC_COST='precomputed_cost'
# MPSC_COST='learned_cost'

MPSC_COST_HORIZON=10

# python3 ./train_rl.py \
python3 ./crazyflie_experiment.py \
--task quadrotor \
--algo lqr \
--algo ppo \
--safety_filter ${MPSC} \
--overrides \
./config_overrides/crazyflie_${TASK}.yaml \
./config_overrides/lqr_crazyflie.yaml \
./config_overrides/ppo_crazyflie.yaml \
./config_overrides/nl_mpsc.yaml \
--output_dir ./models/rl_models/ppo/ \
--kv_overrides \
sf_config.cost_function=${MPSC_COST} \
sf_config.mpsc_cost_horizon=${MPSC_COST_HORIZON} \
# task_config.init_state=None
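The --kv_overrides arguments pass dotted key=value pairs that override individual entries of the merged configuration. A rough sketch of the general mechanism follows; it is not safe-control-gym's ConfigFactory implementation, and values stay strings here whereas a real parser would convert types.

```python
# Illustrative only: apply a dotted key=value override to a nested config dict.
def apply_kv_override(config, kv):
    '''For 'a.b=value', set config['a']['b'] = 'value', creating nodes as needed.'''
    dotted_key, value = kv.split('=', 1)
    *parents, leaf = dotted_key.split('.')
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value
    return config

config = {'sf_config': {'cost_function': 'one_step_cost'}}
apply_kv_override(config, 'sf_config.cost_function=precomputed_cost')
apply_kv_override(config, 'sf_config.mpsc_cost_horizon=10')
print(config)
```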
113 changes: 113 additions & 0 deletions experiments/crazyflie/train_rl.py
@@ -0,0 +1,113 @@
'''Training an RL controller (PPO) on the crazyflie firmware with MPSC in the loop. '''

import os
import shutil
import sys
sys.path.insert(0, '/home/federico/GitHub/safe-control-gym')

from functools import partial

import numpy as np

from experiments.crazyflie.crazyflie_utils import gen_traj
from safe_control_gym.utils.configuration import ConfigFactory
from safe_control_gym.utils.plotting import plot_from_logs
from safe_control_gym.utils.registration import make
from safe_control_gym.utils.utils import mkdirs, set_device_from_config, set_seed_from_config

try:
import pycffirmware
except ImportError:
FIRMWARE_INSTALLED = False
else:
FIRMWARE_INSTALLED = True
finally:
print('Module \'cffirmware\' available:', FIRMWARE_INSTALLED)


def train():
'''The main function creating, running, and closing an environment over N episodes. '''

# Define arguments.
fac = ConfigFactory()
config = fac.merge()
config.algo_config['training'] = True

shutil.rmtree(config.output_dir, ignore_errors=True)

set_seed_from_config(config)
set_device_from_config(config)
CTRL_FREQ = config.task_config['ctrl_freq']

env_func = partial(make,
config.task,
output_dir=config.output_dir,
**config.task_config)

FIRMWARE_FREQ = 500
config.task_config['ctrl_freq'] = FIRMWARE_FREQ
env_func_500 = partial(make,
config.task,
output_dir=config.output_dir,
**config.task_config)

# Create environment.
firmware_wrapper = make('firmware', env_func_500, FIRMWARE_FREQ, CTRL_FREQ)
_, _ = firmware_wrapper.reset()
env = firmware_wrapper.env

# Create trajectory.
full_trajectory = gen_traj(CTRL_FREQ, env.EPISODE_LEN_SEC)
full_trajectory = np.hstack((full_trajectory, full_trajectory))

# Setup controller.
ctrl = make(config.algo,
env_func,
checkpoint_path=os.path.join(config.output_dir, 'model_latest.pt'),
output_dir=config.output_dir,
seed=1,
**config.algo_config)
ctrl.reset()

ctrl.firmware_wrapper = firmware_wrapper
ctrl.X_GOAL = full_trajectory
ctrl.CTRL_DT = 1.0 / CTRL_FREQ

# Setup MPSC.
if config.algo in ['ppo', 'sac']:
safety_filter = make(config.safety_filter,
env_func,
**config.sf_config)
safety_filter.reset()
safety_filter.load(path=f'./models/mpsc_parameters/{config.safety_filter}_crazyflie_track.pkl')

safety_filter.env.X_GOAL = full_trajectory
ctrl.safety_filter = safety_filter

ctrl.learn()
ctrl.close()
print('Training done.')

# with open(os.path.join(config.output_dir, 'config.yaml'), 'w', encoding='UTF-8') as file:
# yaml.dump(munch.unmunchify(config), file, default_flow_style=False)

# make_plots(config)


def make_plots(config):
'''Produces plots for logged stats during training.
Usage
* use with `--func plot` and `--restore {dir_path}` where `dir_path` is
the experiment folder containing the logs.
* save figures under `dir_path/plots/`.
'''
# Define source and target log locations.
log_dir = os.path.join(config.output_dir, 'logs')
plot_dir = os.path.join(config.output_dir, 'plots')
mkdirs(plot_dir)
plot_from_logs(log_dir, plot_dir, window=3)
print('Plotting done.')


if __name__ == '__main__':
train()
2 changes: 1 addition & 1 deletion safe_control_gym/controllers/firmware/firmware_wrapper.py
@@ -164,7 +164,7 @@ def reset(self):
# Initialize controller
if self.CONTROLLER == 'pid':
firm.controllerPidInit()
print('PID controller init test:', firm.controllerPidTest())
# print('PID controller init test:', firm.controllerPidTest())
elif self.CONTROLLER == 'mellinger':
firm.controllerMellingerInit()
assert (self.firmware_freq == 500), 'Mellinger controller requires a firmware frequency of 500Hz.'