
This repository is the official implementation of Improving Object-centric Learning With Query Optimization


YuLiu-LY/BO-QSA


BO-QSA


This repository contains the official implementation of the ICLR 2023 paper:

Improving Object-centric Learning With Query Optimization

Baoxiong Jia*, Yu Liu*, Siyuan Huang

Environment Setup

We list the Python package requirements in requirements.txt. To set up the environment, create a conda environment, install PyTorch, and then install the remaining packages as follows:

conda create -n BO-QSA python=3.8
conda activate BO-QSA
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install -r requirements.txt

In our experiments, we used NVIDIA CUDA 11.3 on Ubuntu 20.04. Similar CUDA versions should also work, provided you install the matching torch and torchvision builds.
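To confirm that the activated environment matches the versions pinned above, a small check like the following can help. It is a minimal sketch, not part of the repository: the function names same_release and check_env are hypothetical, and it assumes that python resolves to the interpreter of the activated BO-QSA environment.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not shipped with BO-QSA): compare installed
# torch/torchvision builds against the versions pinned in the setup commands.

# same_release REPORTED EXPECTED
# Succeeds when REPORTED equals EXPECTED after dropping a local build tag,
# e.g. "1.10.1+cu113" matches "1.10.1", but "1.10.10" does not.
same_release() {
  local reported="${1%%+*}"
  [ "$reported" = "$2" ]
}

# check_env: print installed versions and warn on mismatches.
# Returns non-zero only if torch or torchvision cannot be imported.
check_env() {
  local torch_v vision_v cuda_v gpu
  torch_v=$(python -c 'import torch; print(torch.__version__)') || return 1
  vision_v=$(python -c 'import torchvision; print(torchvision.__version__)') || return 1
  cuda_v=$(python -c 'import torch; print(torch.version.cuda)')
  gpu=$(python -c 'import torch; print(torch.cuda.is_available())')
  echo "torch=$torch_v torchvision=$vision_v cuda=$cuda_v gpu_available=$gpu"
  same_release "$torch_v" 1.10.1 || echo "warning: expected torch 1.10.1" >&2
  same_release "$vision_v" 0.11.2 || echo "warning: expected torchvision 0.11.2" >&2
  [ "$cuda_v" = "11.3" ] || echo "warning: expected a CUDA 11.3 build" >&2
}
```

After conda activate BO-QSA, run check_env. Mismatches are reported as warnings rather than errors, since similar CUDA versions may also work.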

Dataset

1. ShapeStacks, ObjectsRoom, CLEVRTex, Flowers

Download the ShapeStacks, ObjectsRoom, CLEVRTex, and Flowers datasets with:

chmod +x scripts/downloads_data.sh
./scripts/downloads_data.sh

For the ObjectsRoom dataset, run objectsroom_process.py to convert the TFRecords files into PNG images. Remember to change DATA_ROOT in downloads_data.sh and objectsroom_process.py to your own paths.
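Editing DATA_ROOT by hand works fine; if you prefer to script it, the sketch below rewrites the assignment in several files in one call. It assumes (unverified) that each file assigns DATA_ROOT on a single line beginning with DATA_ROOT followed by "=", and set_data_root is a hypothetical helper, not part of the repository. It overwrites everything after the "=" on that line, so review the result before running anything.

```shell
#!/usr/bin/env bash
# Hypothetical helper: point every DATA_ROOT assignment in the given files
# at one directory.
# ASSUMPTION: the root is set on one line such as DATA_ROOT=... (shell)
# or DATA_ROOT = '...' (Python); anything after "=" on that line is replaced.

# set_data_root NEW_ROOT FILE...
set_data_root() {
  local root="$1"; shift
  local escaped f
  # Escape characters that are special in a sed replacement: \, & and the | delimiter.
  escaped=$(printf '%s' "$root" | sed -e 's/[\\&|]/\\&/g')
  for f in "$@"; do
    # Write to a temporary file, then move it back (portable across GNU and BSD sed).
    sed -E "s|^([[:space:]]*DATA_ROOT[[:space:]]*=[[:space:]]*).*|\\1\"$escaped\"|" "$f" > "$f.tmp" &&
      mv "$f.tmp" "$f" || return 1
  done
}
```

For example, pass your data directory followed by the paths of your copies of downloads_data.sh and objectsroom_process.py.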

2. PTR, Birds, Dogs, Cars

Download the PTR dataset following the instructions at http://ptr.csail.mit.edu. Download the CUB-Birds, Stanford Dogs, and Cars datasets from here (provided by the authors of DRC). We use birds.zip, cars.tar, and dogs.zip, and uncompress them after downloading.
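Uncompressing the three archives only needs the standard unzip and tar tools. The helper below is a hypothetical convenience (extract_archive is not part of the repository) that chooses the tool from the file extension; the destination is whatever data directory you use.

```shell
#!/usr/bin/env bash
# Hypothetical helper: unpack a downloaded archive (e.g. birds.zip, cars.tar,
# dogs.zip) into DEST_DIR, picking the tool from the file extension.

# extract_archive ARCHIVE DEST_DIR
extract_archive() {
  local archive="$1" dest="$2"
  mkdir -p "$dest" || return 1
  case "$archive" in
    *.zip)          unzip -q -o "$archive" -d "$dest" ;;   # -o: overwrite without prompting
    *.tar.gz|*.tgz) tar -xzf "$archive" -C "$dest" ;;
    *.tar)          tar -xf "$archive" -C "$dest" ;;
    *) echo "unsupported archive: $archive" >&2; return 1 ;;
  esac
}
```

Usage: for a in birds.zip cars.tar dogs.zip; do extract_archive "$a" /path/to/data; done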

3. YCB, ScanNet, COCO

The YCB, ScanNet, and COCO datasets are available from here, provided by the authors of UnsupObjSeg.

4. Data preparation

Before running experiments, organize the data as described here.
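Before training, it can help to check that every dataset folder is where you expect it. This is a hypothetical sketch: the data root and the folder names are arguments you supply, because the expected layout comes from the data-preparation instructions referenced above and is not reproduced here.

```shell
#!/usr/bin/env bash
# Hypothetical helper: list dataset folders missing under a data root.
# Folder names are supplied by the caller; no layout is assumed here.

# missing_dirs ROOT NAME...
# Prints each NAME with no ROOT/NAME directory; returns 1 if any is missing.
missing_dirs() {
  local root="$1" name status=0; shift
  for name in "$@"; do
    if [ ! -d "$root/$name" ]; then
      echo "$name"
      status=1
    fi
  done
  return "$status"
}
```

Usage: missing_dirs /path/to/data birds dogs cars (these folder names are illustrative).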

Training

To train a model from scratch, we provide the following model files:

  • train_trans_dec.py: transformer-based model
  • train_mixture_dec.py: mixture-based model
  • train_base_sa.py: original slot-attention

We provide training scripts under scripts/train. Use the following commands, replacing the .sh file with the one for the model you want to experiment with. For example, to run the transformer-based decoder experiment on Birds:

$ cd scripts
$ cd train
$ chmod +x trans_dec_birds.sh
$ ./trans_dec_birds.sh
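The cd/chmod/run sequence above leaves your shell inside scripts/train. As a sketch (run_experiment is a hypothetical helper, not shipped with the repository), the same steps can be wrapped in a subshell so that you stay in the repository root; it works the same way for scripts/test.

```shell
#!/usr/bin/env bash
# Hypothetical helper: run one experiment script from its own directory
# without changing the caller's working directory.

# run_experiment SCRIPT_DIR SCRIPT_NAME
#   e.g. run_experiment scripts/train trans_dec_birds.sh
run_experiment() {
  local dir="$1" script="$2"
  (
    cd "$dir" || exit 1          # runs in a subshell, so this cd does not leak out
    chmod +x "$script" || exit 1
    "./$script"
  )
}
```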

Remember to change the paths in path.json to your own paths.
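A wrong entry in path.json may only surface after a run has started, so a quick existence check can save time. The sketch below is hypothetical: the structure of path.json is not documented here, so it assumes every string value in the file (at any nesting depth) is a filesystem path, and it uses python3 only to parse the JSON.

```shell
#!/usr/bin/env bash
# Hypothetical helper: report path.json entries that point to missing paths.
# ASSUMPTION: every string value in the file is a filesystem path.

# check_path_json FILE
# Prints "key: path" per missing path (nested keys joined with "."); returns 1 if any.
check_path_json() {
  python3 - "$1" <<'EOF'
import json
import os
import sys

def walk(node, prefix=""):
    # Yield (dotted key, value) for every string leaf in the JSON tree.
    if isinstance(node, dict):
        for key, value in node.items():
            yield from walk(value, f"{prefix}{key}.")
    elif isinstance(node, list):
        for i, value in enumerate(node):
            yield from walk(value, f"{prefix}{i}.")
    elif isinstance(node, str):
        yield prefix.rstrip("."), node

with open(sys.argv[1]) as fh:
    config = json.load(fh)

missing = [(k, p) for k, p in walk(config) if not os.path.exists(os.path.expanduser(p))]
for key, path in missing:
    print(f"{key}: {path}")
sys.exit(1 if missing else 0)
EOF
}
```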

Reloading checkpoints & Evaluation

To reload checkpoints and only run inference, we provide the following model files:

  • test_trans_dec.py: transformer-based model
  • test_mixture_dec.py: mixture-based model
  • test_base_sa.py: original slot-attention

Similarly, we provide testing scripts under scripts/test. We provide transformer-based models for the real-world datasets (Birds, Dogs, Cars, Flowers, YCB, ScanNet, COCO) and mixture-based models for the synthetic datasets (ShapeStacks, ObjectsRoom, ClevrTex, PTR). All checkpoints are available here. Use the following commands, replacing the .sh file with the one for the model you want to evaluate:
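The dataset split above can be encoded as a lookup when choosing which test script to run. This sketch derives the decoder names trans_dec and mixture_dec from the file names test_trans_dec.py and test_mixture_dec.py; the "<decoder>_<dataset>.sh" naming is an assumption generalized from trans_dec_birds.sh, so check the actual file names in scripts/test.

```shell
#!/usr/bin/env bash
# Hypothetical helpers: map a dataset to the decoder family with a released
# checkpoint (transformer for real-world data, mixture for synthetic data).

# decoder_for DATASET -> trans_dec | mixture_dec (case-insensitive)
decoder_for() {
  case "$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]')" in
    birds|dogs|cars|flowers|ycb|scannet|coco) echo trans_dec ;;
    shapestacks|objectsroom|clevrtex|ptr)     echo mixture_dec ;;
    *) echo "unknown dataset: $1" >&2; return 1 ;;
  esac
}

# test_script_for DATASET -> e.g. trans_dec_birds.sh
# ASSUMPTION: scripts follow the <decoder>_<dataset>.sh pattern.
test_script_for() {
  local dec
  dec=$(decoder_for "$1") || return 1
  echo "${dec}_$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]').sh"
}
```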

$ cd scripts
$ cd test
$ chmod +x trans_dec_birds.sh
$ ./trans_dec_birds.sh

Citation

If you find our paper and/or code helpful, please consider citing:

@inproceedings{jia2023improving,
  title={Improving Object-centric Learning with Query Optimization},
  author={Jia, Baoxiong and Liu, Yu and Huang, Siyuan},
  booktitle={The Eleventh International Conference on Learning Representations},
  year={2023}
}

Acknowledgement

This code builds heavily on resources from SLATE, SlotAttention, GENESISv2, DRC, Multi-Object Datasets, and shapestacks. We thank the authors for open-sourcing their awesome projects.
