
Training your own models

Shreshth Tuli edited this page Jul 11, 2020 · 9 revisions

This page explains how to train the models described in the Tango paper.

Prerequisite: Follow the instructions on the Environment Setup Page.

Pre-trained models: All the trained models mentioned in the Tango paper can be found here.

Model Training

All the models mentioned in the paper can be trained with the following command:

$ python3 train.py $DOMAIN action $MODEL_NAME train

Here, DOMAIN is either home or factory.

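MODEL_NAME selects the PyTorch model to train. The valid names are defined in src/GNN/models.py (ToolNet) and src/GNN/action_models.py (Tango). They are listed below for reference, together with the corresponding names used in the paper. Entries prefixed with "-" are ablated models, i.e. models with the named component removed.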

MODEL_NAME                                  Name in paper
GGCN_Auto_Action                            GGCN+Auto (Baseline)
GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action  Tango
Final_GGCN_Action                           - GGCN
Final_Attn_Action                           - Goal-Conditioned Attn
Final_Cons_Action                           - Constraints
Final_Auto_Action                           - Autoregression
Final_Aseq_Action                           - Temporal Action History
Final_L_Action                              - Factored Likelihood
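The command line and the table above can be captured in a small Python helper, for example to script several training runs. This is a minimal sketch, not part of the Tango repository. `MODEL_NAMES`, `build_train_command` and `model_for_paper_name` are hypothetical names. The argument order is copied from the train.py command shown above, and the literal `action` and `train` arguments are passed through unchanged.

```python
import shlex
from typing import List

# Hypothetical helper (not in the Tango repo): MODEL_NAME values accepted by
# train.py, mapped to their names in the paper, copied from the table above.
MODEL_NAMES = {
    "GGCN_Auto_Action": "GGCN+Auto (Baseline)",
    "GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action": "Tango",
    "Final_GGCN_Action": "- GGCN",
    "Final_Attn_Action": "- Goal-Conditioned Attn",
    "Final_Cons_Action": "- Constraints",
    "Final_Auto_Action": "- Autoregression",
    "Final_Aseq_Action": "- Temporal Action History",
    "Final_L_Action": "- Factored Likelihood",
}
DOMAINS = ("home", "factory")


def build_train_command(domain: str, model_name: str) -> List[str]:
    """Return argv for `python3 train.py DOMAIN action MODEL_NAME train`."""
    if domain not in DOMAINS:
        raise ValueError(f"DOMAIN must be one of {DOMAINS}, got {domain!r}")
    if model_name not in MODEL_NAMES:
        raise ValueError(f"Unknown MODEL_NAME {model_name!r}")
    # Argument order assumed from the command documented on this page.
    return ["python3", "train.py", domain, "action", model_name, "train"]


def model_for_paper_name(paper_name: str) -> str:
    """Reverse lookup: name used in the paper (e.g. 'Tango') -> MODEL_NAME."""
    for model_name, name in MODEL_NAMES.items():
        if name == paper_name:
            return model_name
    raise KeyError(paper_name)


if __name__ == "__main__":
    # Print the shell command for training Tango in the home domain.
    print(shlex.join(build_train_command("home", model_for_paper_name("Tango"))))
```

To actually launch a run, the list can be passed to `subprocess.run(..., check=True)` from the directory that contains train.py (assuming that is where the command above is meant to be run). Validating DOMAIN and MODEL_NAME up front catches typos before a long training job starts.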

This command trains MODEL_NAME on the training dataset for NUM_EPOCHS epochs, where NUM_EPOCHS is set in src/GNN/CONSTANTS.py. After each epoch EPOCH, it saves a checkpoint file trained_models/DOMAIN/MODEL_NAME_EPOCH.ckpt. At the end, it outputs the epoch (say N) with the maximum policy accuracy, selected using an early-stopping criterion. Rename trained_models/DOMAIN/MODEL_NAME_N.ckpt to trained_models/DOMAIN/MODEL_NAME_Trained.ckpt so that it is used for testing. The other checkpoint files can then be deleted.
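The renaming step, and the optional clean-up of the other checkpoints, can be scripted. This is a minimal sketch, assuming the checkpoint layout trained_models/DOMAIN/MODEL_NAME_EPOCH.ckpt described above. `finalize_checkpoint` is a hypothetical helper, not part of the repository, and the best epoch N still has to be read from the output of train.py.

```python
from pathlib import Path


def finalize_checkpoint(domain: str, model_name: str, best_epoch: int,
                        root: str = "trained_models",
                        delete_others: bool = False) -> Path:
    """Rename MODEL_NAME_<best_epoch>.ckpt to MODEL_NAME_Trained.ckpt.

    Hypothetical helper. Assumed layout (from this page):
    root/DOMAIN/MODEL_NAME_EPOCH.ckpt
    """
    ckpt_dir = Path(root) / domain
    best = ckpt_dir / f"{model_name}_{best_epoch}.ckpt"
    if not best.is_file():
        raise FileNotFoundError(best)

    target = ckpt_dir / f"{model_name}_Trained.ckpt"
    best.replace(target)  # overwrites an existing *_Trained.ckpt, if any

    if delete_others:
        prefix = f"{model_name}_"
        # list() so the directory is fully scanned before files are removed.
        for ckpt in list(ckpt_dir.glob(f"{model_name}_*.ckpt")):
            suffix = ckpt.stem[len(prefix):]
            # Remove only per-epoch checkpoints (numeric suffix) of this model;
            # MODEL_NAME_Trained.ckpt and other models' files are kept.
            if suffix.isdigit():
                ckpt.unlink()
    return target
```

For example, `finalize_checkpoint("home", "GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action", best_epoch=N, delete_others=True)`, where N is the epoch reported by train.py. Deletion is off by default so that the per-epoch checkpoints are only removed when explicitly requested.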

Sample Commands

To train the best model (Tango) in the home domain:

python3 train.py home action GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action train

To train the best model (Tango) in the factory domain:

python3 train.py factory action GGCN_Metric_Attn_Aseq_L_Auto_Cons_C_Action train

To train the ablated "- GGCN" model in the home domain:

python3 train.py home action Final_GGCN_Action train

For questions, please contact Shreshth Tuli at shreshthtuli@gmail.com.
