
Replicating results

Shreshth Tuli edited this page Oct 27, 2021 · 20 revisions

This page explains how to replicate results in the Tango paper.

Prerequisite: Follow the instructions on the Environment Setup Page. You should also have all trained models stored in ./trained_models/home/ for the home domain and ./trained_models/factory/ for the factory domain. You can train the models yourself by following this wiki, or download pre-trained models from here.

Evaluation Metrics

We use two evaluation metrics:

  1. Action Prediction Accuracy: the fraction of tool interactions predicted by the model that match the human-demonstrated action a for a given state s.
  2. Plan Execution Accuracy: the fraction of estimated plans that are successful, i.e., that can be executed by the robot in simulation and attain the intended goal (with plans limited to at most 50 actions).
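To make the two definitions concrete, here is a minimal Python sketch of how they could be computed. It is not the code in train.py: the function names, the flat lists of actions, and the `reaches_goal` callback (a stand-in for executing a plan in the simulator) are all hypothetical.

```python
from typing import Callable, Hashable, Sequence

Action = Hashable  # hypothetical: any comparable action representation

MAX_PLAN_LENGTH = 50  # upper bound on plan length from the metric definition


def action_prediction_accuracy(
    predicted: Sequence[Action], demonstrated: Sequence[Action]
) -> float:
    """Fraction of predicted tool interactions equal to the human-demonstrated action."""
    if len(predicted) != len(demonstrated):
        raise ValueError("need one prediction per demonstrated (state, action) pair")
    if not predicted:
        return 0.0
    matches = sum(p == d for p, d in zip(predicted, demonstrated))
    return matches / len(predicted)


def plan_execution_accuracy(
    plans: Sequence[Sequence[Action]],
    reaches_goal: Callable[[Sequence[Action]], bool],
) -> float:
    """Fraction of plans that fit within MAX_PLAN_LENGTH and reach the goal.

    reaches_goal is a hypothetical stand-in for executing the plan in simulation.
    """
    if not plans:
        return 0.0
    successes = sum(
        len(plan) <= MAX_PLAN_LENGTH and reaches_goal(plan) for plan in plans
    )
    return successes / len(plans)
```

Note that a plan longer than 50 actions counts as a failure even if it would eventually reach the goal, which is how the sketch interprets the upper bound.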

Both metrics are calculated on the Test set and on the Generalization Test (GenTest) set.

Determine the accuracy of a model

The accuracy of any model can be obtained using the following command:

$ python3 train.py $DOMAIN action $MODEL_NAME $EXEC_TYPE

This command looks for the checkpoint file trained_models/DOMAIN/MODEL_NAME_Trained.ckpt when you use a pre-trained model, so make sure that this file is present before running it.

Here, DOMAIN is either home or factory.
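A small pre-flight check along these lines can catch a missing checkpoint before a long run. This is a sketch that assumes the checkpoint path pattern described above and that it is run from the repository root; `checkpoint_path` and `missing_checkpoints` are hypothetical helpers, not part of the repository.

```python
from pathlib import Path
from typing import Iterable, List

DOMAINS = ("home", "factory")


def checkpoint_path(domain: str, model_name: str, root: str = "trained_models") -> Path:
    """Return trained_models/DOMAIN/MODEL_NAME_Trained.ckpt for a domain and model."""
    if domain not in DOMAINS:
        raise ValueError(f"DOMAIN must be one of {DOMAINS}, got {domain!r}")
    return Path(root) / domain / f"{model_name}_Trained.ckpt"


def missing_checkpoints(
    model_names: Iterable[str],
    domains: Iterable[str] = DOMAINS,
    root: str = "trained_models",
) -> List[Path]:
    """List checkpoint files that do not exist yet, so they can be trained or downloaded."""
    model_names = list(model_names)  # allow generators; we iterate once per domain
    missing = []
    for domain in domains:
        for model_name in model_names:
            path = checkpoint_path(domain, model_name, root)
            if not path.is_file():
                missing.append(path)
    return missing
```

For example, `missing_checkpoints(["Final_GGCN_Action"])` returns an empty list only if both trained_models/home/Final_GGCN_Action_Trained.ckpt and trained_models/factory/Final_GGCN_Action_Trained.ckpt exist.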

MODEL_NAME specifies which PyTorch model to use. The model names are defined in src/GNN/models.py (ToolNet) and src/GNN/action_models.py (Tango), and are also listed below for reference.

MODEL_NAME                                    Name in paper
GGCN_Auto_Action                              GGCN+Auto (Baseline)
GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action    Tango
Final_GGCN_Action                             - GGCN
Final_Metric_Action                           - Metric
Final_Attn_Action                             - Goal-Conditioned Attn
Final_Cons_Action                             - Constraints
Final_C_Action                                - ConceptNet
Final_Auto_Action                             - Autoregression
Final_Aseq_Action                             - Temporal Action History
Final_L_Action                                - Factored Likelihood

For our ToolTango model, use GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Tool_Action.

EXEC_TYPE can be as follows:

EXEC_TYPE         Meaning
accuracy          Determine the action prediction accuracy of the given model on the Test set
generalization    Calculate the plan execution accuracy of the given model on the GenTest set
policy            Calculate the plan execution accuracy of the given model on the Test set
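If you script the evaluation, validating DOMAIN and EXEC_TYPE before launching avoids silently running a misspelled configuration. The sketch below only assembles the argument list for the command shown earlier; `build_command` is a hypothetical helper, and it does not check MODEL_NAME against the model definitions.

```python
import shlex
from typing import List

DOMAINS = ("home", "factory")
EXEC_TYPES = {
    "accuracy": "action prediction accuracy on the Test set",
    "generalization": "plan execution accuracy on the GenTest set",
    "policy": "plan execution accuracy on the Test set",
}


def build_command(domain: str, model_name: str, exec_type: str) -> List[str]:
    """Build argv for: python3 train.py $DOMAIN action $MODEL_NAME $EXEC_TYPE."""
    if domain not in DOMAINS:
        raise ValueError(f"unknown DOMAIN {domain!r}; expected one of {DOMAINS}")
    if exec_type not in EXEC_TYPES:
        raise ValueError(
            f"unknown EXEC_TYPE {exec_type!r}; expected one of {sorted(EXEC_TYPES)}"
        )
    return ["python3", "train.py", domain, "action", model_name, exec_type]


# Print the shell form of one command (shlex.join requires Python 3.8+).
print(shlex.join(build_command("home", "Final_GGCN_Action", "policy")))
# prints: python3 train.py home action Final_GGCN_Action policy
```

Returning a list rather than a single string means the result can be passed directly to `subprocess.run` without shell quoting issues.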

Sample Commands

To find the action prediction accuracy of the trained "- GGCN" model in the home domain:

python3 train.py home action Final_GGCN_Action accuracy

To find the plan execution accuracy of the trained "- GGCN" model on the Test set:

python3 train.py home action Final_GGCN_Action policy

To find the plan execution accuracy of the trained "- GGCN" model on the GenTest set:

python3 train.py home action Final_GGCN_Action generalization

Replicating the results table

We now describe how to replicate the results table in the Tango paper.

To replicate the Action Prediction columns (action prediction accuracy), run the following command for every model in both domains. The table value is the accuracy reported on the Test set.

python3 train.py $DOMAIN action $MODEL_NAME accuracy

To replicate the Plan Execution columns (plan execution accuracy on the Test set), run the following command for every model in both domains. The table value is the fraction of Correct plans (the first number in the output triple).

python3 train.py $DOMAIN action $MODEL_NAME policy

To replicate the Generalization Plan Execution Accuracy columns (plan execution accuracy on the GenTest set), run the following command for every model in both domains. For each $DOMAIN, the table value is the fraction of Correct plans (the first number in the output triple).

python3 train.py $DOMAIN action $MODEL_NAME generalization
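Running all of these by hand is tedious: the models listed above plus ToolTango, across 2 domains and 3 EXEC_TYPE values, come to 66 runs. A sketch of a sweep script follows, assuming it sits next to train.py and is run from the repository root; the `sweep` function and the eval_logs/ layout are hypothetical conveniences. Reading the numbers out of each log (the accuracy, or the first number of the output triple) is left to you, since the exact output format is not specified here.

```python
import subprocess
from pathlib import Path
from typing import List

DOMAINS = ("home", "factory")
EXEC_TYPES = ("accuracy", "policy", "generalization")
# MODEL_NAME values from the model table on this page, plus ToolTango.
MODEL_NAMES = (
    "GGCN_Auto_Action",
    "GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action",
    "Final_GGCN_Action",
    "Final_Metric_Action",
    "Final_Attn_Action",
    "Final_Cons_Action",
    "Final_C_Action",
    "Final_Auto_Action",
    "Final_Aseq_Action",
    "Final_L_Action",
    "GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Tool_Action",
)


def sweep(log_dir: str = "eval_logs", dry_run: bool = False) -> List[List[str]]:
    """Run every (domain, model, exec_type) combination, saving output per run.

    With dry_run=True nothing is executed; the commands are only returned.
    """
    commands = []
    for domain in DOMAINS:
        for model in MODEL_NAMES:
            for exec_type in EXEC_TYPES:
                cmd = ["python3", "train.py", domain, "action", model, exec_type]
                commands.append(cmd)
                if dry_run:
                    continue
                log_file = Path(log_dir) / domain / f"{model}_{exec_type}.log"
                log_file.parent.mkdir(parents=True, exist_ok=True)
                with log_file.open("w") as fh:
                    # check=False: keep going if one run fails (e.g. missing checkpoint)
                    subprocess.run(cmd, stdout=fh, stderr=subprocess.STDOUT, check=False)
    return commands
```

Calling `sweep(dry_run=True)` first lets you inspect the full list of commands before committing to the actual runs.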

There are 9 test cases in the home domain and 8 in the factory domain. To obtain the table values for the different generalization types, average the fraction of Correct plans over the corresponding test cases from both the home and factory domains. For example, the Position accuracy is the average of three numbers: the fractions of Correct plans in test cases 1 and 2 of the home domain and in test case 8 of the factory domain. Refer to the table below for the test cases belonging to the other generalization types.

Generalization type   Home test cases    Factory test cases
Position              1, 2               8
Alternate             3, 5, 8            1, 2, 5
Unseen                4, 9               4, 6
Robust                7                  3
Goal                  6                  7, 8
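The averaging described above can be written down as a short Python sketch. `correct_fraction` is a hypothetical dictionary you fill in yourself from the generalization runs, keyed by (domain, test case number); the function pools the listed test cases from both domains and takes their plain mean. It reads the factory Goal entry as test cases 7 and 8.

```python
from statistics import mean
from typing import Dict, Tuple

# Test cases per generalization type, taken from the table on this page.
# Assumption: the factory "Goal" entry is read as test cases 7 and 8.
GENERALIZATION_TYPES = {
    "Position": {"home": [1, 2], "factory": [8]},
    "Alternate": {"home": [3, 5, 8], "factory": [1, 2, 5]},
    "Unseen": {"home": [4, 9], "factory": [4, 6]},
    "Robust": {"home": [7], "factory": [3]},
    "Goal": {"home": [6], "factory": [7, 8]},
}


def generalization_table(
    correct_fraction: Dict[Tuple[str, int], float]
) -> Dict[str, float]:
    """Average the fraction of Correct plans over the pooled home + factory test cases.

    correct_fraction maps (domain, test case number) to the fraction of Correct
    plans for that GenTest test case (first number of the output triple).
    """
    return {
        gen_type: mean(
            correct_fraction[(domain, case)]
            for domain, cases in split.items()
            for case in cases
        )
        for gen_type, split in GENERALIZATION_TYPES.items()
    }
```

For Position, this is the mean of home test cases 1 and 2 and factory test case 8, matching the worked example in the text. Pooling the test cases (rather than averaging per-domain means) follows that example.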

In case of queries, please contact Shreshth Tuli at [email protected]
