forked from utiasDSL/safe-control-gym
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding basic training of PPO on crazyflie with MPSF in the loop
- Loading branch information
1 parent
5555e97
commit ef6292b
Showing
11 changed files
with
324 additions
and
126 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
algo: ppo | ||
algo_config: | ||
# model args | ||
hidden_dim: 128 | ||
activation: 'leaky_relu' | ||
norm_obs: False | ||
norm_reward: False | ||
clip_obs: 10.0 | ||
clip_reward: 10.0 | ||
|
||
# loss args | ||
gamma: 0.98 | ||
use_gae: False | ||
gae_lambda: 0.8 | ||
use_clipped_value: False | ||
clip_param: 0.1 | ||
target_kl: 1.587713889686473e-07 | ||
entropy_coef: 0.00010753631441212628 | ||
|
||
# optim args | ||
opt_epochs: 5 | ||
mini_batch_size: 128 | ||
actor_lr: 0.0007948148615930024 | ||
critic_lr: 0.007497368468753617 | ||
max_grad_norm: 0.5 | ||
|
||
# runner args | ||
max_env_steps: 100000 | ||
num_workers: 1 | ||
rollout_batch_size: 1 | ||
rollout_steps: 500 | ||
deque_size: 10 | ||
eval_batch_size: 10 | ||
|
||
# misc | ||
log_interval: 6000 | ||
save_interval: 0 | ||
num_checkpoints: 0 | ||
eval_interval: 6000 | ||
eval_save_best: True | ||
tensorboard: False | ||
|
||
# safety filter | ||
filter_train_actions: True | ||
penalize_sf_diff: True | ||
sf_penalty: 1 | ||
use_safe_reset: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,8 +12,6 @@ | |
Legend._ncol = property(lambda self: self._ncols) | ||
|
||
|
||
k | ||
|
||
plot = True | ||
save_figs = False | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
'''Running MPSC using the crazyflie firmware. ''' | ||
|
||
import os | ||
import shutil | ||
import sys | ||
sys.path.insert(0, '/home/federico/GitHub/safe-control-gym') | ||
|
||
from functools import partial | ||
|
||
import numpy as np | ||
|
||
from experiments.crazyflie.crazyflie_utils import gen_traj | ||
from safe_control_gym.utils.configuration import ConfigFactory | ||
from safe_control_gym.utils.plotting import plot_from_logs | ||
from safe_control_gym.utils.registration import make | ||
from safe_control_gym.utils.utils import mkdirs, set_device_from_config, set_seed_from_config | ||
|
||
try: | ||
import pycffirmware | ||
except ImportError: | ||
FIRMWARE_INSTALLED = False | ||
else: | ||
FIRMWARE_INSTALLED = True | ||
finally: | ||
print('Module \'cffirmware\' available:', FIRMWARE_INSTALLED) | ||
|
||
|
||
def train(): | ||
'''The main function creating, running, and closing an environment over N episodes. ''' | ||
|
||
# Define arguments. | ||
fac = ConfigFactory() | ||
config = fac.merge() | ||
config.algo_config['training'] = True | ||
|
||
shutil.rmtree(config.output_dir, ignore_errors=True) | ||
|
||
set_seed_from_config(config) | ||
set_device_from_config(config) | ||
CTRL_FREQ = config.task_config['ctrl_freq'] | ||
|
||
env_func = partial(make, | ||
config.task, | ||
output_dir=config.output_dir, | ||
**config.task_config) | ||
|
||
FIRMWARE_FREQ = 500 | ||
config.task_config['ctrl_freq'] = FIRMWARE_FREQ | ||
env_func_500 = partial(make, | ||
config.task, | ||
output_dir=config.output_dir, | ||
**config.task_config) | ||
|
||
# Create environment. | ||
firmware_wrapper = make('firmware', env_func_500, FIRMWARE_FREQ, CTRL_FREQ) | ||
_, _ = firmware_wrapper.reset() | ||
env = firmware_wrapper.env | ||
|
||
# Create trajectory. | ||
full_trajectory = gen_traj(CTRL_FREQ, env.EPISODE_LEN_SEC) | ||
full_trajectory = np.hstack((full_trajectory, full_trajectory)) | ||
|
||
# Setup controller. | ||
ctrl = make(config.algo, | ||
env_func, | ||
checkpoint_path=os.path.join(config.output_dir, 'model_latest.pt'), | ||
output_dir=config.output_dir, | ||
seed=1, | ||
**config.algo_config) | ||
ctrl.reset() | ||
|
||
ctrl.firmware_wrapper = firmware_wrapper | ||
ctrl.X_GOAL = full_trajectory | ||
ctrl.CTRL_DT = 1.0 / CTRL_FREQ | ||
|
||
# Setup MPSC. | ||
if config.algo in ['ppo', 'sac']: | ||
safety_filter = make(config.safety_filter, | ||
env_func, | ||
**config.sf_config) | ||
safety_filter.reset() | ||
safety_filter.load(path=f'./models/mpsc_parameters/{config.safety_filter}_crazyflie_track.pkl') | ||
|
||
safety_filter.env.X_GOAL = full_trajectory | ||
ctrl.safety_filter = safety_filter | ||
|
||
ctrl.learn() | ||
ctrl.close() | ||
print('Training done.') | ||
|
||
# with open(os.path.join(config.output_dir, 'config.yaml'), 'w', encoding='UTF-8') as file: | ||
# yaml.dump(munch.unmunchify(config), file, default_flow_style=False) | ||
|
||
# make_plots(config) | ||
|
||
|
||
def make_plots(config): | ||
'''Produces plots for logged stats during training. | ||
Usage | ||
* use with `--func plot` and `--restore {dir_path}` where `dir_path` is | ||
the experiment folder containing the logs. | ||
* save figures under `dir_path/plots/`. | ||
''' | ||
# Define source and target log locations. | ||
log_dir = os.path.join(config.output_dir, 'logs') | ||
plot_dir = os.path.join(config.output_dir, 'plots') | ||
mkdirs(plot_dir) | ||
plot_from_logs(log_dir, plot_dir, window=3) | ||
print('Plotting done.') | ||
|
||
|
||
if __name__ == '__main__': | ||
train() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.