Training your own models
This page gives a tutorial on how to train the models mentioned in the Tango paper.
Prerequisite: Follow the instructions on the Environment Setup Page.
Pre-trained models: All the trained models mentioned in the Tango paper can be found here.
All the models mentioned in the paper can be trained with the command:

```bash
python3 train.py $DOMAIN action $MODEL_NAME train
```
Here, DOMAIN is either home or factory.
MODEL_NAME is the name of the PyTorch model you want to train. Valid names are defined in src/GNN/models.py (ToolNet) and src/GNN/action_models.py (Tango). The names of the models used in the paper are listed below for reference; entries beginning with "-" are ablations of Tango with that component removed.
| MODEL_NAME | Name in paper |
|---|---|
| GGCN_Auto_Action | GGCN+Auto (Baseline) |
| GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action | Tango |
| Final_GGCN_Action | - GGCN |
| Final_Attn_Action | - Goal-Conditioned Attn |
| Final_Cons_Action | - Constraints |
| Final_Auto_Action | - Autoregression |
| Final_Aseq_Action | - Temporal Action History |
| Final_L_Action | - Factored Likelihood |
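To see which names are available in your checkout, you can list the classes defined in the two model files. This is a minimal sketch that assumes each MODEL_NAME corresponds to a top-level class in src/GNN/models.py or src/GNN/action_models.py; helper classes that are not trainable models will be listed as well. `list_model_classes` is a hypothetical helper, not part of the repository.

```python
import ast
from pathlib import Path


def list_model_classes(path):
    """Return the names of top-level classes defined in a Python source file.

    Assumption: each MODEL_NAME is a top-level class in the given file.
    Nested classes are not included.
    """
    tree = ast.parse(Path(path).read_text())
    return [node.name for node in tree.body if isinstance(node, ast.ClassDef)]


if __name__ == "__main__":
    # Paths taken from this page; run from the repository root.
    for source in ("src/GNN/models.py", "src/GNN/action_models.py"):
        if Path(source).exists():
            print(source, list_model_classes(source))
```

Parsing with `ast` avoids importing the modules, so the listing works even before the PyTorch dependencies are installed.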
This command trains MODEL_NAME on the training dataset for NUM_EPOCHS epochs, as set in src/GNN/CONSTANTS.py. After each epoch, it saves a checkpoint file trained_models/DOMAIN/MODEL_NAME_EPOCH.ckpt, where EPOCH is the epoch number. At the end, it outputs the epoch (say N) with the maximum policy accuracy according to the early stopping criterion. Rename trained_models/DOMAIN/MODEL_NAME_N.ckpt to trained_models/DOMAIN/MODEL_NAME_Trained.ckpt for testing. You may delete the other checkpoint files.
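The rename step can be scripted. The following is a minimal sketch that assumes the checkpoint layout described above and that you pass in the epoch N printed by train.py; `promote_checkpoint` and its `delete_others` flag are hypothetical, not part of the repository.

```python
from pathlib import Path


def promote_checkpoint(domain, model_name, best_epoch, root="trained_models", delete_others=False):
    """Rename MODEL_NAME_<best_epoch>.ckpt to MODEL_NAME_Trained.ckpt.

    Assumes checkpoints live at <root>/<domain>/<model_name>_<epoch>.ckpt.
    If delete_others is True, the remaining per-epoch checkpoints of this
    model are removed as well.
    """
    folder = Path(root) / domain
    best = folder / f"{model_name}_{best_epoch}.ckpt"
    target = folder / f"{model_name}_Trained.ckpt"
    if not best.exists():
        raise FileNotFoundError(f"checkpoint not found: {best}")
    best.replace(target)  # overwrites an existing _Trained.ckpt

    if delete_others:
        # Only delete files whose suffix after "<model_name>_" is an epoch number,
        # so checkpoints of other models in the same folder are left untouched.
        for ckpt in list(folder.glob(f"{model_name}_*.ckpt")):
            suffix = ckpt.stem[len(model_name) + 1:]
            if suffix.isdigit():
                ckpt.unlink()
    return target
```

For example, `promote_checkpoint("home", "Final_GGCN_Action", N)` keeps all other checkpoints; deletion is opt-in because removed checkpoints cannot be recovered.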
To train the best model (Tango) in the home domain:

```bash
python3 train.py home action GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action train
```

To train the best model (Tango) in the factory domain:

```bash
python3 train.py factory action GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action train
```

To train the ablated "- GGCN" model in the home domain:

```bash
python3 train.py home action Final_GGCN_Action train
```
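To train every model in the table in both domains, the commands can be generated in a loop. This is a sketch, not a script shipped with the repository: `train_commands` and `run_all` are hypothetical helpers, the runs execute one after another, and by default the script only prints the commands (dry run). Each run still needs its best checkpoint renamed as described above.

```python
import subprocess

# MODEL_NAME values copied from the table on this page.
MODELS = (
    "GGCN_Auto_Action",
    "GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action",
    "Final_GGCN_Action",
    "Final_Attn_Action",
    "Final_Cons_Action",
    "Final_Auto_Action",
    "Final_Aseq_Action",
    "Final_L_Action",
)
DOMAINS = ("home", "factory")


def train_commands(domains=DOMAINS, models=MODELS):
    """Build one `python3 train.py DOMAIN action MODEL_NAME train` argv per pair."""
    return [
        ["python3", "train.py", domain, "action", model, "train"]
        for domain in domains
        for model in models
    ]


def run_all(dry_run=True):
    """Print each command; execute it sequentially only when dry_run is False."""
    for cmd in train_commands():
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failed run


if __name__ == "__main__":
    run_all(dry_run=True)
```

Passing the arguments as a list rather than a shell string avoids quoting issues; set `dry_run=False` once the printed commands look right.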
BSD-2-Clause. Copyright (c) 2020, Shreshth Tuli, Rajas Basal, Rohan Paul and Mausam. All rights reserved.