Replicating results
This page explains how to replicate results in the Tango paper.
Prerequisite: Follow the instructions on the Environment Setup page. You should also have all trained models stored in ./trained_models/home/ for the home domain and ./trained_models/factory/ for the factory domain. You may train your models by following this wiki or download pre-trained models from here.
We use two evaluation metrics:
- Action Prediction Accuracy: This is the fraction of tool interactions predicted by the model that match the human-demonstrated action a for a given state s.
- Plan Execution Accuracy: This is the fraction of estimated plans that are successful, i.e., can be executed by the robot in simulation and attain the intended goal (with an upper bound of 50 actions in the plan).
These two metrics are calculated on the Test set and Generalization Test (GenTest) set.
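As a rough illustration of the two definitions above, the following sketch computes both fractions from already-collected results. It is not the code used in train.py: the function names, the string encoding of actions, and the (num_actions, reached_goal) representation of a plan are assumptions made for this example.

```python
from typing import Sequence, Tuple

MAX_PLAN_LENGTH = 50  # upper bound on actions per plan, as stated above


def action_prediction_accuracy(predicted: Sequence[str],
                               demonstrated: Sequence[str]) -> float:
    """Fraction of predicted actions that match the human-demonstrated action."""
    if len(predicted) != len(demonstrated):
        raise ValueError("need exactly one prediction per demonstrated action")
    if not predicted:
        return 0.0
    matches = sum(p == d for p, d in zip(predicted, demonstrated))
    return matches / len(predicted)


def plan_execution_accuracy(plans: Sequence[Tuple[int, bool]]) -> float:
    """Fraction of plans that reach the goal within MAX_PLAN_LENGTH actions.

    Each plan is a hypothetical (num_actions, reached_goal) pair.
    """
    if not plans:
        return 0.0
    successes = sum(1 for n, ok in plans if ok and n <= MAX_PLAN_LENGTH)
    return successes / len(plans)
```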
Accuracy of any model can be obtained using the following command:
$ python3 train.py $DOMAIN action $MODEL_NAME $EXEC_TYPE
This command looks for the file trained_models/DOMAIN/MODEL_NAME_Trained.ckpt. If you are using a pre-trained model, make sure that the file corresponding to this command is present.
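Before launching an evaluation, it can help to check that the expected checkpoint is in place. The sketch below only rebuilds the path pattern described above. The helper name `checkpoint_path` and its `root` parameter are hypothetical and not part of the repository.

```python
from pathlib import Path


def checkpoint_path(domain: str, model_name: str,
                    root: str = "trained_models") -> Path:
    """Build the checkpoint path that the evaluation command looks for."""
    if domain not in ("home", "factory"):
        raise ValueError("domain must be 'home' or 'factory'")
    # Pattern from the text: trained_models/DOMAIN/MODEL_NAME_Trained.ckpt
    return Path(root) / domain / f"{model_name}_Trained.ckpt"
```

For example, `checkpoint_path("home", "Final_GGCN_Action").exists()` reports whether the file for the "- GGCN" model in the home domain is present, relative to the current working directory.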
Here DOMAIN can be home or factory.
MODEL_NAME specifies the PyTorch model that you want to evaluate. Look at src/GNN/models.py (ToolNet) or src/GNN/action_models.py (Tango) for the available names. The names used in the paper are listed here for reference.
MODEL_NAME | Name in paper |
---|---|
GGCN_Auto_Action | GGCN+Auto (Baseline) |
GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action | Tango |
Final_GGCN_Action | - GGCN |
Final_Metric_Action | - Metric |
Final_Attn_Action | - Goal-Conditioned Attn |
Final_Cons_Action | - Constraints |
Final_C_Action | - ConceptNet |
Final_Auto_Action | - Autoregression |
Final_Aseq_Action | - Temporal Action History |
Final_L_Action | - Factored Likelihood |
For our ToolTango model, use GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Tool_Action.
EXEC_TYPE can be as follows:
EXEC_TYPE | Meaning |
---|---|
accuracy | Determine the action prediction accuracy of the given model on the Test set |
generalization | Calculate the plan execution accuracy of the given model on the GenTest set |
policy | Calculate the plan execution accuracy of the given model on the Test set |
To find the action prediction accuracy of trained "- GGCN" model:
python3 train.py home action Final_GGCN_Action accuracy
To find the plan execution accuracy on Test Set of trained "- GGCN" model:
python3 train.py home action Final_GGCN_Action policy
To find the plan execution accuracy on GenTest Set of trained "- GGCN" model:
python3 train.py home action Final_GGCN_Action generalization
We now describe how to replicate the results table given in the Tango paper.
To replicate the Action Prediction columns (action prediction accuracy), run the following commands for all models for both domains. The value in the table is the accuracy corresponding to the test set.
python3 train.py $DOMAIN action $MODEL_NAME accuracy
To replicate the Plan Execution columns (plan execution accuracy on the Test set), run the following commands for all models for both domains. The value in the table is the fraction of Correct plans (the first number in the output triple).
python3 train.py $DOMAIN action $MODEL_NAME policy
To replicate the Generalization Plan Execution Accuracy columns (plan execution accuracy on the GenTest set), run the following commands for all models for both domains. The value in the table corresponding to the $DOMAIN is the fraction of Correct plans (the first number in the output triple).
python3 train.py $DOMAIN action $MODEL_NAME generalization
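Running every combination by hand is tedious, so the three commands above can be generated in a loop. The sketch below is one possible way to do this. The helpers `build_commands` and `run_all` are hypothetical, the model list is copied from the MODEL_NAME table, and `run_all` assumes it is started from the directory that contains train.py.

```python
import itertools
import subprocess

DOMAINS = ["home", "factory"]
EXEC_TYPES = ["accuracy", "policy", "generalization"]
# Model names copied from the MODEL_NAME table on this page.
MODEL_NAMES = [
    "GGCN_Auto_Action",
    "GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action",
    "Final_GGCN_Action",
    "Final_Metric_Action",
    "Final_Attn_Action",
    "Final_Cons_Action",
    "Final_C_Action",
    "Final_Auto_Action",
    "Final_Aseq_Action",
    "Final_L_Action",
]


def build_commands(domains=DOMAINS, models=MODEL_NAMES, exec_types=EXEC_TYPES):
    """Return one argument list per (domain, model, exec type) combination."""
    return [
        ["python3", "train.py", domain, "action", model, exec_type]
        for domain, model, exec_type in itertools.product(domains, models, exec_types)
    ]


def run_all(commands):
    """Run the commands one after another, stopping at the first failure."""
    for cmd in commands:
        print(" ".join(cmd))
        subprocess.run(cmd, check=True)
```

Calling `run_all(build_commands())` would launch all 60 evaluations in sequence. Passing a shorter list, for example `build_commands(exec_types=["accuracy"])`, restricts the run to one column group.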
There are 9 test cases in the home domain and 8 in the factory domain. To obtain the table values corresponding to the different generalization types, take the average of the fraction of Correct plans over the corresponding test cases in both the home and factory domains. For example, to calculate the Position accuracy, take the average of three numbers: the fraction of correct plans in test cases 1 and 2 in home and test case 8 in factory. Refer to the table below for the other generalization types.
Generalization Type | Testcases in Home domain | Testcases in Factory domain |
---|---|---|
Position | 1,2 | 8 |
Alternate | 3,5,8 | 1,2,5 |
Unseen | 4,9 | 4,6 |
Robust | 7 | 3 |
Goal | 6 | 7,8 |
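The averaging described above can be sketched as follows. This is an illustration rather than code from the repository: the function name and the nested-dictionary input format are assumptions, and reading the Goal entry for the factory domain as test cases 7 and 8 is also an assumption.

```python
# Test cases per generalization type, taken from the table above.
GENERALIZATION_CASES = {
    "Position": {"home": [1, 2], "factory": [8]},
    "Alternate": {"home": [3, 5, 8], "factory": [1, 2, 5]},
    "Unseen": {"home": [4, 9], "factory": [4, 6]},
    "Robust": {"home": [7], "factory": [3]},
    "Goal": {"home": [6], "factory": [7, 8]},  # assumption: test cases 7 and 8
}


def generalization_accuracy(gen_type, correct_fraction):
    """Average the fraction of Correct plans over one generalization type.

    correct_fraction maps a domain name to a dict {testcase number: fraction},
    filled in by hand from the outputs of the generalization runs.
    """
    values = [
        correct_fraction[domain][case]
        for domain, cases in GENERALIZATION_CASES[gen_type].items()
        for case in cases
    ]
    return sum(values) / len(values)
```

Each test case contributes equally to the average, regardless of which domain it comes from, which matches the Position example of averaging three numbers.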
BSD-2-Clause. Copyright (c) 2020, Shreshth Tuli, Rajas Basal, Rohan Paul and Mausam. All rights reserved.