Implementation of MILO, a model-based, offline imitation learning algorithm.
Link to pdf: https://arxiv.org/abs/2106.03207
After cloning this repository and installing the requirements, please run the following in order from the repository root:
cd milo && pip install -e .
cd ../mjrl && pip install -e .
The experiments are run using MuJoCo physics, which requires a license to install. Please follow the instructions on the MuJoCo website.
The milo package contains our imitation learning, model-based environment stack, and boilerplate code. We modified the mjrl package to interface with our cost functions when doing model-based policy gradient. This modification can be seen in mjrl/mjrl/algos/batch_reinforce.py. Note that we currently support only NPG/TRPO as our policy gradient algorithm; however, in principle one could replace this with other algorithms/repositories.
This repository supports 5 modified MuJoCo environments that can be found in milo/milo/gym_env. They are:
- Hopper-v4
- Walker2d-v4
- HalfCheetah-v4
- Ant-v4
- Humanoid-v4
If you would like to add an environment, register it in milo/milo/gym_env/__init__.py according to the OpenAI Gym instructions.
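As a rough illustration of what such a registration might look like, here is a minimal sketch using Gym's `register` function. The environment id, module path, class name, and episode length are hypothetical placeholders, not names taken from this repository; adapt them to the environment you are adding.

```python
# Sketch only: the id, entry_point, and max_episode_steps values below are
# hypothetical placeholders, not taken from this repository.
try:
    from gym.envs.registration import register
except ImportError:  # gym is not installed; keep the sketch importable
    register = None

NEW_ENV_SPEC = dict(
    id="MyEnv-v0",                                # hypothetical environment id
    entry_point="milo.gym_env.my_env:MyEnv",      # hypothetical module:class path
    max_episode_steps=1000,                       # assumed episode horizon
)


def register_new_env(spec=NEW_ENV_SPEC):
    """Register the environment with Gym; return False if gym is unavailable."""
    if register is None:
        return False
    register(**spec)
    return True
```

In practice the `register(...)` call would usually sit directly in the package's `__init__.py` so that importing the package makes the id available to `gym.make`.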
Please download the datasets from this Google Drive link. Each environment has 2 datasets: [ENV]_expert.pt and [ENV]_offline.pt.
In the data directory, place the expert and offline datasets in the data/expert_data and data/offline_data directories, respectively.
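To check that the files ended up where the layout above expects them, a small helper along these lines may be useful. It is a sketch, not part of the repository: the function names are hypothetical, and it assumes that [ENV] in the file names is replaced by whatever environment name the downloaded files actually use.

```python
from pathlib import Path


def dataset_paths(env, data_root="data"):
    """Return the expected expert and offline dataset paths for one environment.

    Assumes the layout described above: data/expert_data/[ENV]_expert.pt and
    data/offline_data/[ENV]_offline.pt. The exact [ENV] string is an assumption
    and should match the names of the downloaded files.
    """
    root = Path(data_root)
    return {
        "expert": root / "expert_data" / f"{env}_expert.pt",
        "offline": root / "offline_data" / f"{env}_offline.pt",
    }


def missing_datasets(env, data_root="data"):
    """List the expected dataset files that do not exist on disk."""
    return [str(p) for p in dataset_paths(env, data_root).values() if not p.exists()]
```

For example, `missing_datasets("Hopper-v4")` returns an empty list once both files are in place, and otherwise lists the paths that still need to be filled.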
We provide an example run script for Hopper, example_run.sh, that can be modified for use with any other registered environment. To view all available arguments, see the argparse definitions in milo/milo/utils/arguments.py.
To cite this work, please use the following citation. Note that this repository builds upon MJRL, so please also cite any references noted in the README here.
@misc{chang2021mitigating,
title={Mitigating Covariate Shift in Imitation Learning via Offline Data Without Great Coverage},
author={Jonathan D. Chang and Masatoshi Uehara and Dhruv Sreenivas and Rahul Kidambi and Wen Sun},
year={2021},
eprint={2106.03207},
archivePrefix={arXiv},
primaryClass={cs.LG}
}