This repository contains supplementary code and experiments for my analysis of AdaptiveBC (Zhao et al.) and Improved TD3+BC (Beeson et al.). The analysis write-up is available at https://api.wandb.ai/links/ummagumm-a/0n0kqm6r.
This is a fork, so for the CORL-specific installation guide, see either the original CORL page or the CORL-specific README section below.
I assume you have followed CORL's installation guide and are now inside the Docker container.
First, create a working directory and switch to it. Otherwise, the names of Tinkoff's CORL and my forked CORL will conflict:
mkdir workdir && cd workdir
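Create an SSH key inside the Docker container so that you can clone the repositories:
ssh-keygen
(press Enter at every prompt)
Copy the output of cat ~/.ssh/id_rsa.pub (the file name depends on the key type; newer OpenSSH versions create id_ed25519.pub by default)
and add it as an SSH key to your GitHub account here: https://github.com/settings/keys.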
Next, clone the AdaptiveBC and Improved TD3+BC repositories:
git clone git@github.com:ummagumm-a/adaptive_bc.git
git clone git@github.com:ummagumm-a/CORL.git
To run experiments with the WandB logger, you first need to log in:
wandb login
Then paste the API key from https://wandb.ai/settings.
Install additional libraries:
pip install optuna stable_baselines3 scikit-learn
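Before starting long runs, it can help to confirm that the extra libraries are importable inside the container. The following is a minimal sketch using only the standard library; the list of module names is an assumption derived from the pip command above (note that scikit-learn is imported as sklearn).

```python
import importlib.util

# Import names for the packages installed above.
# Assumption: scikit-learn is imported under the module name "sklearn".
REQUIRED_MODULES = ["optuna", "stable_baselines3", "sklearn"]


def missing_modules(names):
    """Return the module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

If any module is reported as missing, rerun the pip command above inside the same environment.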
To run AdaptiveBC experiments:
cd adaptive_bc
python main.py
To run Improved TD3+BC experiments:
cd CORL
python algorithms/improved_td3_bc.py
I used the AdaptiveBC authors' implementation: https://github.com/zhaoyi11/adaptive_bc. The only changes I made were logging additional variables and technical adjustments to 'cuda' usage. For the main implementation details, see the original repo.
I used CORL's algorithms/td3_bc.py as the foundation for my implementation of Beeson's paper. I extended CORL's code to the online learning setting and implemented Beeson's scheduling procedure. The whole implementation is located in algorithms/improved_td3_bc.py.
To save time in my experiments, I pretrained models on the offline dataset with 5 different seeds. The file pretrain_script contains the command to train one of these models.
First, install MySQL for distributed optimization, create a database and create an optuna study:
apt update
apt install sudo
sudo apt install mysql-server
sudo service mysql start
mysql -u root -e "CREATE DATABASE IF NOT EXISTS improved_td3_bc_tune_replay"
sudo apt-get install python3-mysqldb
pip install mysqlclient
optuna create-study --study-name "improved_td3_bc_tune_replay" --storage="mysql://root@localhost/improved_td3_bc_tune_replay"
To start hyperparameter optimization, run:
python algorithms/improved_td3_bc.py --hyper_tune=True --load_model_for_tune=<path>
Note that the current code assumes that there are 5 pretrained models located in <path>.
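Because the study is stored in a MySQL database, several worker processes can in principle contribute trials to it in parallel. The sketch below only assembles and prints the worker commands. It assumes that every run of algorithms/improved_td3_bc.py attaches to the study created above (check the script to confirm this); the directory name pretrained_models and the worker count are hypothetical placeholders.

```python
import shlex
import subprocess


def build_worker_commands(model_dir, n_workers):
    """Build one tuning command per worker; all workers use the same flags."""
    cmd = [
        "python",
        "algorithms/improved_td3_bc.py",
        "--hyper_tune=True",
        f"--load_model_for_tune={model_dir}",
    ]
    return [list(cmd) for _ in range(n_workers)]


def launch_workers(commands):
    """Start all workers, wait for them, and return their exit codes."""
    processes = [subprocess.Popen(cmd) for cmd in commands]
    return [process.wait() for process in processes]


if __name__ == "__main__":
    # "pretrained_models" is a hypothetical directory holding the 5 models.
    commands = build_worker_commands("pretrained_models", n_workers=4)
    for cmd in commands:
        print(shlex.join(cmd))
    # Uncomment to actually start the workers from the CORL directory:
    # launch_workers(commands)
```

Printing the commands first makes it easy to verify the path before committing GPU time to a long tuning run.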
If you don't want to run the whole optimization procedure from scratch, use my database dumps located in tune_trials_dump.
I saved some of the Optuna-generated graphs to optuna_results. These include importance graphs, contour plots, and Pareto front plots.
🧵 CORL is an Offline Reinforcement Learning library that provides high-quality and easy-to-follow single-file implementations of SOTA ORL algorithms. Each implementation is backed by a research-friendly codebase, allowing you to run or tune thousands of experiments. It is heavily inspired by cleanrl for online RL; check them out too!
- 📜 Single-file implementation
- 📈 Benchmarked Implementation for N algorithms
- 🖼 Weights and Biases integration
git clone https://github.com/tinkoff-ai/CORL.git && cd CORL
pip install -r requirements/requirements_dev.txt
# alternatively, you could use docker
docker build -t <image_name> .
docker run --gpus=all -it --rm --name <container_name> <image_name>
For learning curves and all the details, you can check the links above. Here, we report reproduced final and best scores. Note that they differ by a large margin, and papers do not always state explicitly which of the two reporting methodologies they use.
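To make the difference between the two reporting schemes concrete, here is a minimal sketch on made-up evaluation curves. It uses one common convention: average the seeds at every evaluation step, then take the last point of the averaged curve as the final score and its highest point as the best score. This convention and the numbers are assumptions for illustration, not necessarily the exact procedure behind the tables.

```python
from statistics import mean, pstdev


def final_and_best(curves):
    """Return ((final_mean, final_std), (best_mean, best_std)).

    `curves` holds one list of evaluation scores per seed, in training order.
    """
    # zip(*curves) groups the scores of all seeds at the same evaluation step.
    per_step = [(mean(step), pstdev(step)) for step in zip(*curves)]
    final = per_step[-1]
    best = max(per_step, key=lambda mean_std: mean_std[0])
    return final, best


# Hypothetical normalized scores for two seeds and three evaluations.
curves = [
    [10.0, 60.0, 40.0],
    [20.0, 70.0, 50.0],
]
final, best = final_and_best(curves)
print("final:", final)
print("best:", best)
```

In this toy example the averaged curve peaks at the second evaluation and then drops, so the best score is noticeably higher than the final one.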
Final scores, Gym-MuJoCo:

Task-Name | BC | BC-10% | TD3 + BC | CQL | IQL | AWAC | SAC-N | EDAC | DT | LB-SAC |
---|---|---|---|---|---|---|---|---|---|---|
halfcheetah-medium-v2 | 42.40±0.21 | 42.46±0.81 | 48.10±0.21 | 47.08±0.19 | 48.31±0.11 | 50.01±0.30 | 68.20±1.48 | 67.70±1.20 | 42.20±0.30 | 71.21±1.35 |
halfcheetah-medium-expert-v2 | 55.95±8.49 | 90.10±2.83 | 90.78±6.98 | 95.98±0.83 | 94.55±0.21 | 95.29±0.91 | 98.96±10.74 | 104.76±0.74 | 91.55±1.10 | 106.57±3.90 |
halfcheetah-medium-replay-v2 | 35.66±2.68 | 23.59±8.02 | 44.84±0.68 | 45.19±0.58 | 43.53±0.43 | 44.91±1.30 | 60.70±1.17 | 62.06±1.27 | 38.91±0.57 | 64.10±0.82 |
hopper-medium-v2 | 53.51±2.03 | 55.48±8.43 | 60.37±4.03 | 64.98±6.12 | 62.75±6.02 | 63.69±4.29 | 40.82±11.44 | 101.70±0.32 | 65.10±1.86 | 103.75±0.07 |
hopper-medium-expert-v2 | 52.30±4.63 | 111.16±1.19 | 101.17±10.48 | 93.89±14.34 | 106.24±6.09 | 105.29±7.19 | 101.31±13.43 | 105.19±11.64 | 110.44±0.39 | 110.93±0.51 |
hopper-medium-replay-v2 | 29.81±2.39 | 70.42±9.99 | 64.42±24.84 | 87.67±14.42 | 84.57±13.49 | 98.15±2.85 | 100.33±0.90 | 99.66±0.94 | 81.77±7.93 | 102.53±0.92 |
walker2d-medium-v2 | 63.23±18.76 | 67.34±5.97 | 82.71±5.51 | 80.38±3.45 | 84.03±5.42 | 69.39±31.97 | 87.47±0.76 | 93.36±1.60 | 67.63±2.93 | 90.95±0.65 |
walker2d-medium-expert-v2 | 98.96±18.45 | 108.70±0.29 | 110.03±0.41 | 109.68±0.52 | 111.68±0.56 | 111.16±2.41 | 114.93±0.48 | 114.75±0.86 | 107.11±1.11 | 113.46±2.31 |
walker2d-medium-replay-v2 | 21.80±11.72 | 54.35±7.32 | 85.62±4.63 | 79.24±4.97 | 82.55±8.00 | 71.73±13.98 | 78.99±0.58 | 87.10±3.21 | 59.86±3.15 | 87.95±1.43 |
locomotion average | 50.40 | 69.29 | 76.45 | 78.23 | 79.80 | 78.85 | 83.52 | 92.92 | 73.84 | 94.60 |

Final scores, Maze2d:

Task-Name | BC | BC-10% | TD3 + BC | CQL | IQL | AWAC | SAC-N | EDAC | DT |
---|---|---|---|---|---|---|---|---|---|
maze2d-umaze-v1 | 0.36±10.03 | 12.18±4.95 | 29.41±14.22 | -14.83±0.47 | 37.69±1.99 | 68.30±25.72 | 130.59±19.08 | 95.26±7.37 | 18.08±29.35 |
maze2d-medium-v1 | 0.79±3.76 | 14.25±2.69 | 59.45±41.86 | 86.62±11.11 | 35.45±0.98 | 82.66±46.71 | 88.61±21.62 | 57.04±3.98 | 31.71±30.40 |
maze2d-large-v1 | 2.26±5.07 | 11.32±5.88 | 97.10±29.34 | 33.22±43.66 | 49.64±22.02 | 218.87±3.96 | 204.76±1.37 | 95.60±26.46 | 35.66±32.56 |
maze2d average | 1.13 | 12.58 | 61.99 | 35.00 | 40.92 | 123.28 | 141.32 | 82.64 | 28.48 |

Final scores, Antmaze:

Task-Name | BC | BC-10% | TD3 + BC | CQL | IQL | AWAC | SAC-N | EDAC | DT |
---|---|---|---|---|---|---|---|---|---|
antmaze-umaze-v0 | 51.50±8.81 | 67.75±6.40 | 93.25±1.50 | 72.75±5.32 | 74.50±11.03 | 63.50±9.33 | 0.00±0.00 | 29.25±33.35 | 51.75±11.76 |
antmaze-medium-play-v0 | 0.00±0.00 | 2.50±1.91 | 0.00±0.00 | 0.00±0.00 | 71.50±12.56 | 0.00±0.00 | 0.00±0.00 | 0.00±0.00 | 0.00±0.00 |
antmaze-large-play-v0 | 0.00±0.00 | 0.00±0.00 | 0.00±0.00 | 0.00±0.00 | 40.75±12.69 | 0.00±0.00 | 0.00±0.00 | 0.00±0.00 | 0.00±0.00 |
antmaze average | 17.17 | 23.42 | 31.08 | 24.25 | 62.25 | 21.17 | 0.00 | 9.75 | 17.25 |

Best scores, Gym-MuJoCo:

Task-Name | BC | BC-10% | TD3 + BC | CQL | IQL | AWAC | SAC-N | EDAC | DT | LB-SAC |
---|---|---|---|---|---|---|---|---|---|---|
halfcheetah-medium-v2 | 43.60±0.16 | 43.90±0.15 | 48.93±0.13 | 47.45±0.10 | 48.77±0.06 | 50.87±0.21 | 72.21±0.35 | 69.72±1.06 | 42.73±0.11 | 71.82±0.68 |
halfcheetah-medium-expert-v2 | 79.69±3.58 | 94.11±0.25 | 96.59±1.01 | 96.74±0.14 | 95.83±0.38 | 96.87±0.31 | 111.73±0.55 | 110.62±1.20 | 93.40±0.25 | 110.37±0.47 |
halfcheetah-medium-replay-v2 | 40.52±0.22 | 42.27±0.53 | 45.84±0.30 | 46.38±0.14 | 45.06±0.16 | 46.57±0.27 | 67.29±0.39 | 66.55±1.21 | 40.31±0.32 | 66.14±1.06 |
hopper-medium-v2 | 69.04±3.35 | 73.84±0.43 | 70.44±1.37 | 77.47±6.00 | 80.74±1.27 | 99.40±1.12 | 101.79±0.23 | 103.26±0.16 | 69.42±4.21 | 103.88±0.17 |
hopper-medium-expert-v2 | 90.63±12.68 | 113.13±0.19 | 113.22±0.50 | 112.74±0.07 | 111.79±0.47 | 113.37±0.63 | 111.24±0.17 | 111.80±0.13 | 111.18±0.24 | 110.93±0.51 |
hopper-medium-replay-v2 | 68.88±11.93 | 90.57±2.38 | 98.12±1.34 | 102.20±0.38 | 102.33±0.44 | 101.76±0.43 | 103.83±0.61 | 103.28±0.57 | 88.74±3.49 | 104.00±0.94 |
walker2d-medium-v2 | 80.64±1.06 | 82.05±1.08 | 86.91±0.32 | 84.57±0.15 | 87.99±0.83 | 86.22±4.58 | 90.17±0.63 | 95.78±1.23 | 74.70±0.64 | 90.95±0.65 |
walker2d-medium-expert-v2 | 109.95±0.72 | 109.90±0.10 | 112.21±0.07 | 111.63±0.20 | 113.19±0.33 | 113.40±2.57 | 116.93±0.49 | 116.52±0.86 | 108.71±0.39 | 113.46±2.31 |
walker2d-medium-replay-v2 | 48.41±8.78 | 76.09±0.47 | 91.17±0.83 | 89.34±0.59 | 91.85±2.26 | 87.06±0.93 | 85.18±1.89 | 89.69±1.60 | 68.22±1.39 | 92.25±2.20 |
locomotion average | 70.15 | 80.65 | 84.83 | 85.39 | 86.40 | 88.39 | 95.60 | 96.36 | 77.49 | 95.97 |

Best scores, Maze2d:

Task-Name | BC | BC-10% | TD3 + BC | CQL | IQL | AWAC | SAC-N | EDAC | DT |
---|---|---|---|---|---|---|---|---|---|
maze2d-umaze-v1 | 16.09±1.00 | 22.49±1.75 | 99.33±18.66 | 84.92±34.40 | 44.04±3.02 | 141.92±12.88 | 153.12±7.50 | 149.88±2.27 | 63.83±20.04 |
maze2d-medium-v1 | 19.16±1.44 | 27.64±2.16 | 150.93±4.50 | 137.52±9.83 | 92.25±40.74 | 160.95±11.64 | 93.80±16.93 | 154.41±1.82 | 68.14±14.15 |
maze2d-large-v1 | 20.75±7.69 | 41.83±4.20 | 197.64±6.07 | 153.29±12.86 | 138.70±44.70 | 228.00±2.06 | 207.51±1.11 | 182.52±3.10 | 50.25±22.33 |
maze2d average | 18.67 | 30.65 | 149.30 | 125.25 | 91.66 | 176.96 | 151.48 | 162.27 | 60.74 |

Best scores, Antmaze:

Task-Name | BC | BC-10% | TD3 + BC | CQL | IQL | AWAC | SAC-N | EDAC | DT |
---|---|---|---|---|---|---|---|---|---|
antmaze-umaze-v0 | 71.25±9.07 | 79.50±2.38 | 97.75±1.50 | 85.00±3.56 | 87.00±2.94 | 74.75±8.77 | 0.00±0.00 | 75.00±27.51 | 60.50±3.11 |
antmaze-medium-play-v0 | 4.75±2.22 | 8.50±3.51 | 6.00±2.00 | 3.00±0.82 | 86.00±2.16 | 14.00±11.80 | 0.00±0.00 | 0.00±0.00 | 0.25±0.50 |
antmaze-large-play-v0 | 0.75±0.50 | 11.75±2.22 | 0.50±0.58 | 0.50±0.58 | 53.00±6.83 | 0.00±0.00 | 0.00±0.00 | 0.00±0.00 | 0.00±0.00 |
antmaze average | 25.58 | 33.25 | 34.75 | 29.50 | 75.33 | 29.58 | 0.00 | 25.00 | 20.25 |
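
The average rows appear to be the unweighted mean of the per-task mean scores in each column, with the standard deviations left out. A minimal sketch that reproduces the BC entry of the antmaze average row directly above from its three task cells (the helper name column_average is hypothetical):

```python
def column_average(cells):
    """Average the mean scores (the part before '±') of one table column."""
    means = [float(cell.split("±")[0]) for cell in cells]
    return round(sum(means) / len(means), 2)


# BC column of the Antmaze table directly above.
bc_antmaze = ["71.25±9.07", "4.75±2.22", "0.75±0.50"]
print(column_average(bc_antmaze))  # 25.58, matching the average row
```

The same helper can be applied to any other column to cross-check its average entry.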
If you use CORL in your work, please cite it with the following BibTeX entry:
@inproceedings{
tarasov2022corl,
title={{CORL}: Research-oriented Deep Offline Reinforcement Learning Library},
author={Denis Tarasov and Alexander Nikulin and Dmitry Akimov and Vladislav Kurenkov and Sergey Kolesnikov},
booktitle={3rd Offline RL Workshop: Offline RL as a ''Launchpad''},
year={2022},
url={https://openreview.net/forum?id=SyAS49bBcv}
}