
Pipeline of Pre-Training RDT

First, install the prerequisites for RDT (see the README). Then, install the prerequisites for TensorFlow Dataset (in a separate Conda environment).

Installation for TensorFlow Dataset

```bash
# Under the root directory of this repo
conda create -n rdt-data python=3.10
conda activate rdt-data

# Install all the prerequisites
pip install -r requirements_data.txt
# Or you can manually install each package
pip install tfds-nightly gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
# If the download speed is too slow, you can specify an alternative source (tfds-nightly is not available in the Tsinghua mirror)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
```

Download and Prepare Datasets

We introduce how to download each of our pre-training datasets. If you plan to pre-train on a subset of them, just download the ones you need. You can also fine-tune RDT through this pipeline, but only if your target dataset is listed below or available in Google Cloud Storage.

| Dataset | Sample Percentage (%) |
| --- | --- |
| RT-1 Dataset | 9.00 |
| TACO Dataset | 1.99 |
| JACO Play Dataset | 1.10 |
| Cable Routing Dataset | 0.27 |
| NYU Door Opening | 0.33 |
| Viola | 0.40 |
| Berkeley UR5 | 1.06 |
| TOTO | 1.06 |
| Kuka | 1.66 |
| Language Table | 3.32 |
| Columbia Cairlab Pusht Real | 0.40 |
| Stanford Kuka Multimodal Dataset | 1.83 |
| Stanford Hydra Dataset | 0.80 |
| Austin Buds Dataset | 0.23 |
| Maniskill Dataset | 5.78 |
| Furniture Bench Dataset | 2.36 |
| UCSD Kitchen Dataset | 0.40 |
| UCSD Pick And Place Dataset | 1.23 |
| Austin Sailor Dataset | 0.50 |
| Austin Sirius Dataset | 0.80 |
| BC Z | 6.91 |
| UTokyo PR2 Opening Fridge | 0.30 |
| UTokyo PR2 Tabletop Manipulation | 0.50 |
| UTokyo Xarm Pick And Place | 0.33 |
| UTokyo Xarm Bimanual | 0.03 |
| Berkeley MVP | 0.73 |
| Berkeley RPT | 1.00 |
| KAIST Nonprehensile | 0.46 |
| Tokyo U LSMO | 0.23 |
| DLR Sara Grid Clamp | 0.03 |
| Robocook | 1.66 |
| Imperialcollege Sawyer Wrist Cam | 0.43 |
| Iamlab CMU Pickup Insert | 0.83 |
| UTAustin Mutex | 1.29 |
| Fanuc Manipulation | 0.66 |
| Play Fusion | 0.80 |
| Droid | 10.06 |
| FMB | 1.39 |
| Dobb·E | 1.20 |
| QUT Dexterous Manipulation | 0.46 |
| Aloha Dataset | 4.98 |
| Mobile Aloha Dataset | 4.98 |
| Roboset | 4.48 |
| RH20T | 10.99 |
| Calvin Dataset | 3.32 |
| Bridgev2 | 7.44 |

Before everything, let's link the dataset directory on your disk to a subfolder of this repo:

```bash
ln -s /path/to/dataset /path/to/repo/RoboticsDiffusionTransformer/data/datasets
```

Open X-Embodiment

Specify the correct path to the gsutil installed in your Conda environment in this file.

Run the following commands to download our selected datasets for the Open X-Embodiment:

```bash
# Under the root directory of this repo
cd data/openx_embod
# Download all datasets
bash download_openx_embod.sh
```

Note: By modifying download_openx_embod.sh, you can download any dataset on the Google Cloud (as long as it can be downloaded with gsutil and is stored in TFRecord format), not just the ones we have listed.

Mobile ALOHA Dataset

Download the Mobile ALOHA Dataset from the official website to data/datasets/aloha, then run:

```bash
cd data/aloha
# Convert the dataset to TFRecord
python hdf5totfrecords.py
```

Bridgev2

Run:

```bash
cd data/bridgev2
# Download and preprocess the dataset
sh download.sh
```

Calvin

Run:

```bash
cd data/calvin
# Download and preprocess the dataset
sh download.sh
# Convert the dataset to TFRecord format
python hdf5totfrecords.py
```

RH20T

Download the RH20T Dataset from their official website to data/datasets/rh20t, then run:

```bash
cd data/rh20t
# Convert the dataset to TFRecord
python hdf5totfrecords.py
```

RoboSet

Run:

```bash
cd data/roboset
# Download and preprocess the dataset
sh download.sh
```

If You Want to Train on a New Dataset

If you want to train on a new dataset (e.g., my_pretrain_dataset) through this pre-training pipeline, you need to modify several files as follows:

1. configs/dataset_control_freq.json

Add the control frequency of your dataset.
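
For example, if your data was collected at 25 Hz (a made-up value for illustration), the new entry could look like the following, assuming the file maps dataset names to control frequencies:

```json
{
  "my_pretrain_dataset": 25
}
```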

2. data/preprocess_scripts/my_pretrain_dataset.py

If your dataset can be loaded by tfds.builder_from_directory(), then you only need to download it into the Open X-Embodiment folder data/datasets/openx_embod and implement the process_step() function. You may need to specify the tfds loading path at L78 (see this file). We refer to data/preprocess_scripts/droid.py for an example.

If not, you need to first convert it into TFRecords and then implement both load_dataset() and process_step(). We refer to data/agilex/hdf5totfrecords.py and data/preprocess_scripts/agilex.py for examples.

Here are some descriptions of the two functions (a minimal sketch of both follows this list):

load_dataset(seed: int)

- Returns a dataset that supports the iterator and repeat methods, shuffled with the given random seed.
- Suggested implementation: use tf.data.Dataset.from_generator and tf.data.TFRecordDataset.
- Iterating over it should yield one episode at a time; each episode is itself a sub-dataset that supports iteration and has the following structure:
  - step: A dataset object that supports iteration and contains the frames of the episode.
    - observation: A dictionary containing your images.
      - your_first_image_key: Your observation RGB image keys.
      - ...
    - other_attribute: Any other relevant attributes.

process_step(step: dict) -> dict

- Processes a single frame and returns a dictionary with the following keys:
  - observation:
    - your_first_view_image: tf.Tensor: Your first-view image.
    - arm_concat: tf.Tensor: Concatenation of physical states.
    - format: tf.constant(string): Format of arm_concat (e.g., arm_joint_pos_0,arm_joint_pos_1,arm_joint_pos_2).
  - action: The frame action (leave it empty if there is none).
    - arm_concat: Same as in observation.
    - format: Same as in observation.
    - terminate: tf.Tensor: A boolean tensor indicating whether the episode ends.

IMPORTANT: You should only use TensorFlow functions for any branch or loop operations. For example, use tf.cond instead of if.
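
Below is a minimal sketch of what data/preprocess_scripts/my_pretrain_dataset.py could look like for a dataset that cannot be loaded by tfds.builder_from_directory(). Everything dataset-specific here is a hypothetical placeholder rather than the repo's actual schema: the TFRECORD_DIR path, the per-frame feature names (exterior_rgb, arm_joint_pos, gripper_width, terminate), the 7-DoF state format, and the 0.04 m gripper threshold. Adapt them to whatever your own TFRecord converter writes.

```python
# data/preprocess_scripts/my_pretrain_dataset.py -- a minimal sketch.
# Every feature name, shape, and path below is a hypothetical placeholder.
import os
import random

import tensorflow as tf

TFRECORD_DIR = "data/datasets/my_pretrain_dataset/tfrecords"  # assumed location

# Hypothetical per-frame schema written by your TFRecord converter.
FEATURE_SPEC = {
    "exterior_rgb": tf.io.FixedLenFeature([], tf.string),     # JPEG-encoded image
    "arm_joint_pos": tf.io.FixedLenFeature([7], tf.float32),  # 7-DoF joint positions
    "gripper_width": tf.io.FixedLenFeature([], tf.float32),   # gripper opening (m)
    "terminate": tf.io.FixedLenFeature([], tf.int64),         # 1 on the last frame
}

FRAME_SPEC = {
    "observation": {
        "exterior_rgb": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        "arm_joint_pos": tf.TensorSpec(shape=(7,), dtype=tf.float32),
        "gripper_width": tf.TensorSpec(shape=(), dtype=tf.float32),
    },
    "terminate": tf.TensorSpec(shape=(), dtype=tf.bool),
}

STATE_FORMAT = (
    "arm_joint_pos_0,arm_joint_pos_1,arm_joint_pos_2,arm_joint_pos_3,"
    "arm_joint_pos_4,arm_joint_pos_5,arm_joint_pos_6,gripper_open"
)


def _parse_frame(raw_record):
    """Decode one serialized tf.train.Example into a frame dictionary."""
    frame = tf.io.parse_single_example(raw_record, FEATURE_SPEC)
    return {
        "observation": {
            "exterior_rgb": tf.io.decode_jpeg(frame["exterior_rgb"], channels=3),
            "arm_joint_pos": frame["arm_joint_pos"],
            "gripper_width": frame["gripper_width"],
        },
        "terminate": tf.cast(frame["terminate"], tf.bool),
    }


def load_dataset(seed: int):
    """Return a dataset of episodes, each carrying a nested 'step' sub-dataset."""
    filenames = sorted(
        os.path.join(TFRECORD_DIR, f)
        for f in os.listdir(TFRECORD_DIR)
        if f.endswith(".tfrecord")
    )
    random.Random(seed).shuffle(filenames)  # episode-level shuffling with the seed

    def episode_generator():
        # Assumption of this sketch: one TFRecord file per episode.
        for fname in filenames:
            yield {"step": tf.data.TFRecordDataset(fname).map(_parse_frame)}

    return tf.data.Dataset.from_generator(
        episode_generator,
        output_signature={"step": tf.data.DatasetSpec(FRAME_SPEC)},
    )


def process_step(step: dict) -> dict:
    """Convert one parsed frame into the unified observation/action format."""
    # Graph-safe branching (per the note above): tf.cond instead of Python `if`.
    gripper_open = tf.cond(
        step["observation"]["gripper_width"] > 0.04,  # hypothetical "open" threshold
        lambda: tf.constant(1.0),
        lambda: tf.constant(0.0),
    )
    arm_concat = tf.concat(
        [step["observation"]["arm_joint_pos"], tf.expand_dims(gripper_open, 0)],
        axis=0,
    )
    return {
        "observation": {
            "exterior_rgb": step["observation"]["exterior_rgb"],
            "arm_concat": arm_concat,
            "format": tf.constant(STATE_FORMAT),
        },
        "action": {
            # The pipeline normally derives actions from future states (see Step 5),
            # so the current state is reused here only as a placeholder.
            "arm_concat": arm_concat,
            "format": tf.constant(STATE_FORMAT),
            "terminate": step["terminate"],
        },
    }
```

Whatever image key you use here (exterior_rgb in this sketch) must match the key you register for this dataset in configs/dataset_img_keys.json in Step 3.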

3. configs/dataset_img_keys.json

Add the image keys of your dataset. For example:

"my_pretrain_dataset": {
  "image_keys": [
    "exterior-cam",
    "right-wrist-cam",
    "left-wrist-cam",
    "left-wrist-cam"
  ],
  "image_mask": [1, 1, 1, 0]
}
- To make TensorFlow happy, you have to specify four images in this order: exterior-cam, right-wrist-cam, left-wrist-cam, any-cam. Each key should correspond to an observation image key of your step attribute.

- If you only have a single wrist camera, just treat it as a right wrist.

- The image_mask indicates whether each image is valid (1) or not (0).

- What if you don't have four images? Simply repeat existing images in the remaining positions and set their masks to 0 (invalid).

- The key order is strict. If you don't have the exterior camera but have both wrists, pad the exterior position with one of the wrist images, mask it as invalid, and use the following:

  ```json
  "my_pretrain_dataset": {
    "image_keys": [
      "right-wrist-cam",
      "right-wrist-cam",
      "left-wrist-cam",
      "left-wrist-cam"
    ],
    "image_mask": [0, 1, 1, 0]
  }
  ```

- During training, only the first three cameras will be used.

4. configs/dataset_stat.json

Compute the statistics (min, max, mean, and std) for your dataset:

```bash
# Use -h to see the full usage
python -m data.compute_dataset_stat --skip_exist
```

This will update the dataset_stat.json file with your dataset's statistics.

5. data/vla_dataset.py
- Add your dataset to DATASET_NAMES_NOOPENX if it cannot be loaded by tfds.builder_from_directory() (see the sketch after this list).
- If your dataset only contains actions but no proprioception (i.e., robot states), add it to DATASET_NAMES_NO_STATE in this file.
- Normally, we take the future state as the action of the current timestep. If you want to use different actions, you should implement more functions. We refer to flatten_episode_agilex() in this file and _generate_json_state_agilex() in this file for examples. You may also refer to L318 in this file and L128 in this file for how to select your dataset and preprocess it differently.
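
For instance, registering the hypothetical my_pretrain_dataset in both lists could look like this (existing entries elided):

```python
# data/vla_dataset.py -- hypothetical additions for my_pretrain_dataset
DATASET_NAMES_NOOPENX = [
    # ... existing entries ...
    "my_pretrain_dataset",  # cannot be loaded by tfds.builder_from_directory()
]

DATASET_NAMES_NO_STATE = [
    # ... existing entries ...
    "my_pretrain_dataset",  # only actions, no proprioception
]
```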

Start Pre-Training

We employ a producer-consumer framework with TensorFlow Dataset for fast data loading. Since most of the datasets in the Open X-Embodiment are stored as TFRecord, we convert all pre-training datasets into TFRecord for storage. During pre-training, a producer process decompresses the data from TFRecord and stores it in a buffer on the hard disk; at the same time, a consumer process reads data from the buffer in random order and feeds it to model training. This not only decouples the TensorFlow and PyTorch environments but also alleviates the training performance loss caused by the small size of an in-memory shuffling buffer.
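
The sketch below is only a toy illustration of this idea, not the repo's actual producer/consumer implementation: the producer spills decoded samples into individual files under the disk buffer, while the consumer reads files back in random order, so the TensorFlow side and the PyTorch side never need to share a process.

```python
# A toy illustration of the disk shuffling buffer idea (not the repo's code).
import os
import pickle
import random

BUF_PATH = "/path/to/buffer"  # plays the role of buf_path in the config


def produce(sample_iterator, max_slots=10_000):
    """Producer side: decode samples (e.g., from TFRecord) and spill them to disk."""
    os.makedirs(BUF_PATH, exist_ok=True)
    for i, sample in enumerate(sample_iterator):
        # Write slots cyclically so the on-disk buffer stays bounded in size.
        with open(os.path.join(BUF_PATH, f"{i % max_slots}.pkl"), "wb") as f:
            pickle.dump(sample, f)


def consume():
    """Consumer side: read back one buffered sample, chosen uniformly at random."""
    filename = random.choice(os.listdir(BUF_PATH))
    with open(os.path.join(BUF_PATH, filename), "rb") as f:
        return pickle.load(f)
```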

This file includes configurations relevant to the model architecture (number of heads, hidden dimension, and so on) and data processing. You may need to modify buf_path (L22) to your real buffer path; this buffer serves as the disk shuffling buffer for data loading.

Configurations relevant to training are passed through command-line arguments. Use python main.py -h to see their descriptions. We provide an example pre-training script in this file (pretrain.sh). You may need to modify some of the parameters in this file, such as CUTLASS_PATH and WANDB_PROJECT.

You may need to modify the list of pre-training datasets in this file and their corresponding sampling weights in this file. If you want to fine-tune RDT through this pipeline, you may need to remove the datasets you don't need from the list.

Before starting pre-training, first launch the data producer process (if you use multiple nodes, run this command on each node):

```bash
# Under the root directory of this repo
conda activate rdt-data
# Use -h to see the full usage
python -m data.producer --fill_up
# Please proceed to the next step ONLY AFTER the filling-up process has finished
```

Then, we run the pre-training script:

```bash
source pretrain.sh
```

Note: You can monitor the training process by observing the loss (smoothed with a long-window moving average), overall_avg_sample_mse, and the sampling MSE of each dataset in Wandb or TensorBoard. We have empirically found that the lower the overall_avg_sample_mse, the better the model performs.