
M2I: From Factored Marginal Trajectory Prediction to Interactive Prediction

Paper | Project Page

Overview

This repository can be used to reproduce the results published in our paper. We are continuously working on modeling interactions between road users and on planning/prediction-focused simulation. Follow us on Google Scholar for our latest work on these topics.

What you can expect in this repository 🤟:

  • Code to predict with pre-trained relation predictor
  • Code to predict with pre-trained marginal trajectory predictor
  • Code to predict with pre-trained conditional trajectory predictor
  • Code to train the relation predictor on WOMD
  • Code to train the marginal trajectory predictor on WOMD
  • Code to train the conditional trajectory predictor on WOMD

What you might not find in this repository 😵:

  • Code to generate the ground truth relation labels; instead, you can download the labels we used to train our relation predictor
  • The simulator we used for visualization; the visualization toolkit provided by Waymo is an alternative
  • Code to pack and submit the prediction results to the Waymo Interactive Motion Prediction Challenge

Prepare Your Dataset

Log in and download the dataset from the Waymo Open Dataset website. We use the tf.Example proto files from the interactive validation/testing splits.
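If you prefer the command line, the data can also be fetched with gsutil once your access is approved. This is only a hedged sketch: the bucket path below is an assumption and depends on the dataset release you registered for, so verify it on the Waymo Open Dataset download page first.

# Assumed bucket/version; adjust to the release you have access to.
gsutil -m cp -r gs://waymo_open_dataset_motion_v_1_1_0/uncompressed/tf_example/validation_interactive ./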

As only a subset of scenarios in the training dataset have annotated interactive agents, we provide a script to extract interactive scenarios from the training data. Details are provided below.

Cython

Our run.py script compiles the Cython source utils_cython.pyx on each run. If that does not work on your machine, compile it manually with:

cd src
cython -a utils_cython.pyx && python setup.py build_ext --inplace

Quick Start

Requires:

  • Python 3.6
  • PyTorch 1.6+

Install packages into a Conda environment (Cython, tensorflow, waymo-open-dataset, etc.):

conda env create -f conda.cuda111.yaml
conda activate M2I

Download our prediction results and pre-trained models from Google Drive, then run the following commands for a quick prediction.

Relation Prediction

Download our relation prediction results m2i.relation.v2v.VAL on the interactive validation dataset, or download the ground truth labels validation_interactive_gt_relations.pickle and the pre-trained relation model m2i.relation.v2v.zip, unpack them into the project folder, and run the following command to predict:

OUTPUT_DIR=m2i.relation.v2v; \
DATA_DIR=./validation_interactive/; \
RELATION_GT_DIR=./validation_interactive_gt_relations.pickle; \
python -m src.run --waymo --data_dir ${DATA_DIR} \
--config relation.yaml --output_dir ${OUTPUT_DIR} \
--future_frame_num 80 \
--relation_file_path ${RELATION_GT_DIR} --agent_type vehicle \
--distributed_training 1 -e --nms_threshold 7.2 \
--validation_model 25 --relation_pred_threshold 0.9

Marginal Trajectory Prediction

Download the marginal prediction results validation_interactive_m2i_v.pickle on the interactive validation dataset, or run the following command to predict with the pre-trained marginal prediction model densetnt.raster.vehicle.1.zip (unpack first):

OUTPUT_DIR=densetnt.raster.vehicle.1; \
DATA_DIR=./validation_interactive/; \
python -m src.run --do_train --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --hidden_size 128 --train_batch_size 64 \
--sub_graph_batch_size 4096 --core_num 16 \
--other_params l1_loss densetnt goals_2D enhance_global_graph laneGCN point_sub_graph laneGCN-4 stride_10_2 raster train_pair_interest save_rst \
--dist 1 --future_frame_num 80 --agent_type vehicle -e --nms 7.2 --eval_exp_path validation_interactive_v_rdensetnt_full

Conditional Trajectory Prediction

Download our pre-trained conditional prediction model m2i.conditional.v2v.zip and unpack it to predict trajectories of the reactors by running:

OUTPUT_DIR=m2i.conditional.v2v; \
DATA_DIR=./validation_interactive/; \
RELATION_GT_DIR=./validation_interactive_gt_relations.pickle; \
RELATION_PRED_DIR=./m2i.relation.v2v.VAL; \
INFLUENCER_PRED_DIR=./validation_interactive_m2i_v.pickle; \
python -m src.run --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --config conditional_pred.yaml \
--relation_file_path ${RELATION_GT_DIR} \
--relation_pred_file_path ${RELATION_PRED_DIR} \
--influencer_pred_file_path ${INFLUENCER_PRED_DIR} \
--future_frame_num 80 \
-e --eval_rst_saving_number 0 \
--eval_exp_path ${RESULT_EXPORT_PATH}

The files validation_interactive_gt_relations.pickle, m2i.relation.v2v.VAL, and validation_interactive_m2i_v.pickle can be found on Google Drive.

This command outputs 6 predictions conditioned on one influencer prediction, selected by --eval_rst_saving_number. Vary this value from 0 to 5 to obtain all 6 groups of conditional predictions, for example with the loop sketched below.
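A minimal shell sketch of that sweep, reusing the variables defined in the command above (RESULT_EXPORT_PATH is whatever export path you chose):

# Run the conditional predictor once per influencer prediction (0..5).
for i in 0 1 2 3 4 5; do
python -m src.run --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --config conditional_pred.yaml \
--relation_file_path ${RELATION_GT_DIR} \
--relation_pred_file_path ${RELATION_PRED_DIR} \
--influencer_pred_file_path ${INFLUENCER_PRED_DIR} \
--future_frame_num 80 \
-e --eval_rst_saving_number ${i} \
--eval_exp_path ${RESULT_EXPORT_PATH}
done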

Performance

Results of the pre-trained models on the Waymo Open Motion Dataset interactive prediction benchmark:

Set              Type        minFDE  MR    mAP
Validation (8s)  Vehicle     5.49    0.55  0.18
Validation (8s)  Pedestrian  3.61    0.60  0.06
Validation (8s)  Cyclist     6.26    0.73  0.04
Test (8s)        Vehicle     5.65    0.57  0.16
Test (8s)        Pedestrian  3.73    0.60  0.06
Test (8s)        Cyclist     6.16    0.74  0.03

Training

Filtering interactive training data

The Waymo Open Dataset does not provide a separate interactive training partition. We provide a script to extract interactive scenarios from the training set:

python scripts/filter_interactive_data.py -i TRAINING_DATA_DIR -o TRAINING_INTERACTIVE_DATA_DIR
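For example, assuming the raw training shards sit in ./training/ and you want the filtered scenarios where the training commands below expect them (both directory names are assumptions):

python scripts/filter_interactive_data.py -i ./training/ -o ./training_interactive/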

Training Relation Predictor

Download the ground truth relation data training_interactive_gt_relations.pickle from Google drive and run the following command to train a relation predictor:

DATA_DIR=./training_interactive/; \
RELATION_GT_DIR=./training_interactive_gt_relations.pickle; \
python -m src.run --do_train --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --hidden_size 128 --train_batch_size 16 --sub_graph_batch_size 1024  --core_num 16 \
--future_frame_num 80 \
--relation_file_path ${RELATION_GT_DIR} --weight_decay 0.3 --agent_type vehicle \
--other_params train_relation pair_vv \
l1_loss densetnt goals_2D enhance_global_graph laneGCN point_sub_graph laneGCN-4 stride_10_2 raster \
--distributed_training 8

This command trains a relation predictor for vehicle-vehicle interactions. Replace pair_vv with pair_vc, pair_vp, or pair_others to train for vehicle-cyclist, vehicle-pedestrian, or the remaining type combinations, respectively (see the sketch below). Set --distributed_training (or --dist) to the number of GPUs available for training.
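As an illustration, here is a hedged sketch of a vehicle-cyclist run on 4 GPUs; the output directory name is hypothetical, and keeping --agent_type vehicle for this pair is an assumption:

# Hedged sketch: vehicle-cyclist relation predictor; other flags mirror the command above.
OUTPUT_DIR=m2i.relation.v2c; \
DATA_DIR=./training_interactive/; \
RELATION_GT_DIR=./training_interactive_gt_relations.pickle; \
python -m src.run --do_train --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --hidden_size 128 --train_batch_size 16 --sub_graph_batch_size 1024 --core_num 16 \
--future_frame_num 80 \
--relation_file_path ${RELATION_GT_DIR} --weight_decay 0.3 --agent_type vehicle \
--other_params train_relation pair_vc \
l1_loss densetnt goals_2D enhance_global_graph laneGCN point_sub_graph laneGCN-4 stride_10_2 raster \
--distributed_training 4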

Training Marginal Predictor

Use the following command to train a marginal predictor:

OUTPUT_DIR=waymo.densetnt.raster.1; \
DATA_DIR=./training_interactive/; \
python src/run.py --do_train --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --hidden_size 128 --train_batch_size 64 --sub_graph_batch_size 4096 --core_num 16 \
--other_params l1_loss densetnt goals_2D enhance_global_graph laneGCN point_sub_graph laneGCN-4 stride_10_2 raster \
--dist 8 --future_frame_num 80 --agent_type vehicle

This command trains a marginal predictor for the vehicle type. Change --agent_type to pedestrian or cyclist to train predictors for the other agent types, as in the example below.
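A hedged sketch of the pedestrian variant; only the agent type changes, and the output directory name is hypothetical:

OUTPUT_DIR=waymo.densetnt.raster.pedestrian.1; \
DATA_DIR=./training_interactive/; \
python src/run.py --do_train --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --hidden_size 128 --train_batch_size 64 --sub_graph_batch_size 4096 --core_num 16 \
--other_params l1_loss densetnt goals_2D enhance_global_graph laneGCN point_sub_graph laneGCN-4 stride_10_2 raster \
--dist 8 --future_frame_num 80 --agent_type pedestrian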

Training Conditional Predictor

Use the following command to train the conditional predictor with the ground truth influencer trajectories and relations:

DATA_DIR=./training_interactive/; \
RELATION_GT_DIR=./training_interactive_gt_relations.pickle; \
python -m src.run --do_train --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --hidden_size 128 --train_batch_size 64 --sub_graph_batch_size 4096  --core_num 10 \
--future_frame_num 80 --agent_type vehicle \
--relation_file_path ${RELATION_GT_DIR} --weight_decay 0.3 \
--infMLP 8 --other_params train_reactor gt_relation_label gt_influencer_traj pair_vv raster_inf raster \
l1_loss densetnt goals_2D enhance_global_graph laneGCN point_sub_graph laneGCN-4 stride_10_2 \
--distributed_training 8

To select the reactor agent type, change the value of --agent_type. To select both the reactor and influencer types, change the pair_vv flag, as in the sketch below.
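For instance, a hedged sketch with vehicle influencers and pedestrian reactors; interpreting pair_vp as vehicle-influencer/pedestrian-reactor, and the output directory name, are assumptions:

OUTPUT_DIR=m2i.conditional.v2p; \
DATA_DIR=./training_interactive/; \
RELATION_GT_DIR=./training_interactive_gt_relations.pickle; \
python -m src.run --do_train --waymo --data_dir ${DATA_DIR} \
--output_dir ${OUTPUT_DIR} --hidden_size 128 --train_batch_size 64 --sub_graph_batch_size 4096 --core_num 10 \
--future_frame_num 80 --agent_type pedestrian \
--relation_file_path ${RELATION_GT_DIR} --weight_decay 0.3 \
--infMLP 8 --other_params train_reactor gt_relation_label gt_influencer_traj pair_vp raster_inf raster \
l1_loss densetnt goals_2D enhance_global_graph laneGCN point_sub_graph laneGCN-4 stride_10_2 \
--distributed_training 8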

Access with License Agreement

Following Waymo's policy, we may ask you to sign their 'Waymo Dataset License Agreement for Non-Commercial Use' before downloading our pre-trained models.

Citation

If you find this repo useful for your research, please consider citing:

@inproceedings{sun2022m2i,
  title={{M2I}: From Factored Marginal Trajectory Prediction to Interactive Prediction},
  author={Sun, Qiao and Huang, Xin and Gu, Junru and Williams, Brian and Zhao, Hang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}

About

M2I is a simple but effective joint motion prediction framework that combines marginal and conditional predictions by exploiting the factored relations between interacting agents.
