
rs9000/VisualReasoning_MMnet


VisualReasoning_MMnet

Models in PyTorch for the visual reasoning task on the CLEVR dataset.

Stacked attention networks (SAN):
https://arxiv.org/pdf/1511.02274.pdf

Module networks (PG):
https://arxiv.org/pdf/1705.03633.pdf

Yes, but what's new?
Try to achieve the same performance with an end-to-end differentiable architecture:
Module memory network [new]
Module memory network end2end differentiable [new]

Try to achieve weak supervision:
(Work in progress)

Set-up

Step 1: Download the data

mkdir data
wget https://s3-us-west-1.amazonaws.com/clevr/CLEVR_v1.0.zip -O data/CLEVR_v1.0.zip
unzip data/CLEVR_v1.0.zip -d data
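Before extracting features it can help to confirm that the archive unpacked where the later commands expect it. The sketch below only checks the two paths used in Steps 2 and 3 of this README; the helper name `missing_paths` is hypothetical and not part of the repository.

```python
from pathlib import Path

# Paths relative to data/, taken from the commands in Steps 2 and 3.
EXPECTED = [
    "CLEVR_v1.0/images/train",
    "CLEVR_v1.0/questions/CLEVR_train_questions.json",
]

def missing_paths(data_dir="data"):
    """Return the expected CLEVR paths that do not exist under data_dir."""
    root = Path(data_dir)
    return [p for p in EXPECTED if not (root / p).exists()]

if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("CLEVR layout looks complete.")
```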

Step 2: Extract Image Features

python scripts/extract_features.py \
  --input_image_dir data/CLEVR_v1.0/images/train \
  --output_h5_file data/train_features.h5
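The shape of each stored feature map depends on the CNN and the stage at which it is truncated. Assuming `extract_features.py` follows the setup of the module network paper (ResNet-101 cut after its third residual stage, "conv4"), which is not confirmed here, the per-image shape can be worked out from the standard convolution output-size formula. The function names are hypothetical; check the script's defaults for the real image size and stage.

```python
def conv_out(size, kernel, stride, padding):
    """Standard output-size formula for a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

# Output channels after each residual stage of ResNet-50/101/152 (bottleneck blocks).
STAGE_CHANNELS = {1: 256, 2: 512, 3: 1024, 4: 2048}

def resnet_feature_shape(height, width, stage=3):
    """(channels, H, W) of ResNet features truncated after `stage` (1 to 4)."""
    def spatial(n):
        n = conv_out(n, kernel=7, stride=2, padding=3)  # conv1
        n = conv_out(n, kernel=3, stride=2, padding=1)  # max pool
        # Stages 2..stage each downsample once with a stride-2 convolution.
        for _ in range(2, stage + 1):
            n = conv_out(n, kernel=3, stride=2, padding=1)
        return n
    return STAGE_CHANNELS[stage], spatial(height), spatial(width)

print(resnet_feature_shape(224, 224))  # (1024, 14, 14)
print(resnet_feature_shape(320, 480))  # (1024, 20, 30) at CLEVR's native 480x320
```

The overall stride of 16 at stage 3 is why a 224x224 input yields a 14x14 grid of 1024-dimensional region features.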

Step 3: Preprocess Questions

python scripts/preprocess_questions.py \
  --input_questions_json data/CLEVR_v1.0/questions/CLEVR_train_questions.json \
  --output_h5_file data/train_questions.h5 \
  --output_vocab_json data/vocab.json
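Question preprocessing turns each question into a fixed-length sequence of word indices and records the word-to-index mapping in `vocab.json`. The sketch below shows the general idea only: the tokenization rules, the special tokens `<NULL>`, `<START>`, `<END>`, `<UNK>`, and all function names are assumptions, not the behaviour of `preprocess_questions.py`.

```python
# Hypothetical special tokens; check data/vocab.json for the ones actually used.
NULL, START, END, UNK = "<NULL>", "<START>", "<END>", "<UNK>"

def tokenize(question):
    """Assumed rules: lower-case, keep ';' and ',' as tokens, drop '?' and '.'."""
    question = question.lower()
    for p in (";", ","):
        question = question.replace(p, f" {p}")
    for p in ("?", "."):
        question = question.replace(p, "")
    return question.split()

def build_vocab(questions):
    """Assign indices to special tokens first, then to words in order of appearance."""
    vocab = {NULL: 0, START: 1, END: 2, UNK: 3}
    for q in questions:
        for tok in tokenize(q):
            vocab.setdefault(tok, len(vocab))
    return vocab

def encode(question, vocab, max_len):
    """Wrap in <START>/<END>, map unknown words to <UNK>, truncate, pad with <NULL>."""
    ids = [vocab[START]] + [vocab.get(t, vocab[UNK]) for t in tokenize(question)] + [vocab[END]]
    ids = ids[:max_len]
    return ids + [vocab[NULL]] * (max_len - len(ids))

vocab = build_vocab(["Is there a red cube?", "How many spheres are there?"])
print(encode("Is there a blue cube?", vocab, max_len=10))  # "blue" becomes <UNK>
```

Under this scheme the dictionary size passed to `train.py` as `--question_size` would presumably be `len(vocab)`.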

Test sample

Train

python train.py [arguments]

arguments:
  --model               Model to train: SAN, SAN_wbw, PG, PG_memory, PG_endtoend
  --question_size       Number of words in question dictionary
  --stem-dim            Number of feature-maps
  --n-channel           Number of feature channels
  --batch_size          Mini-batch size
  --min_grad            Minimum value for gradient clipping
  --max_grad            Maximum value for gradient clipping
  --load_model_path     Load a pre-trained model (path)
  --load_model_mode     Load model mode: Execution Engine (EE), Program Generator (PG), Both (PG+EE)
  --save_model          Save the model (bool)
  --clevr_dataset       CLEVR dataset (path)
  --clevr_val_images    CLEVR validation images (path)
  --num_iterations      Number of iterations per epoch
  --num_val_samples     Number of validation samples
  --batch_multiplier    Virtual batch multiplier (minimum value: 1)
  --train_mode          Train mode: Execution Engine (EE), Program Generator (PG), Both (PG+EE)
  --decoder_mode        Program generator mode: backpropagation (soft, gumbel) or REINFORCE (hard, hard+penalty)
  --use_curriculum      Use curriculum learning to train the program generator (bool)
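The `--batch_multiplier`, `--min_grad` and `--max_grad` options suggest a virtual batch (gradients accumulated over several mini-batches before one optimizer step) combined with element-wise gradient clamping. That reading is an assumption. The plain-Python sketch below illustrates it on a toy one-parameter model `y_hat = w * x`; the function names are hypothetical.

```python
def grad_mse(w, batch):
    """d/dw of the mean squared error of the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def virtual_batch_step(w, batches, lr, min_grad, max_grad):
    """One update of w from len(batches) mini-batches (the batch multiplier).

    Per-mini-batch gradients are averaged, which matches scaling each loss by
    1 / batch_multiplier before backward(); the result is clamped to
    [min_grad, max_grad] and only then applied.
    """
    grad = sum(grad_mse(w, b) for b in batches) / len(batches)
    grad = max(min_grad, min(max_grad, grad))  # element-wise value clipping
    return w - lr * grad

# batch_size = 2 with batch_multiplier = 2 behaves like one batch of 4 (data: y = 2x).
batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
w = 0.0
for _ in range(50):
    w = virtual_batch_step(w, batches, lr=0.01, min_grad=-10.0, max_grad=10.0)
print(w)  # close to 2.0
```

In PyTorch the same pattern would be calling `loss.backward()` once per mini-batch, then `p.grad.clamp_(min_grad, max_grad)` on each parameter, then a single `optimizer.step()` and `optimizer.zero_grad()`. Accumulation lets the effective batch exceed what fits in GPU memory.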
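The backpropagation decoder modes keep the program generator differentiable, so it can be trained end to end with the execution engine. A common way to do this is the Gumbel-softmax relaxation: add Gumbel(0, 1) noise to the token logits, then apply a temperature-scaled softmax. The sketch below assumes that `gumbel` refers to this standard trick and that `soft` is a plain softmax. The repository's exact implementation is not shown here, and the function names are hypothetical.

```python
import math
import random

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def gumbel_softmax(logits, tau=1.0, rng=random):
    """Relaxed (differentiable) sample of one token from `logits`.

    Adds Gumbel(0, 1) noise to each logit and applies a softmax with
    temperature `tau`; a small tau pushes the result towards a one-hot vector.
    """
    eps = 1e-12
    noise = []
    for _ in logits:
        u = min(max(rng.random(), eps), 1.0 - eps)  # keep U strictly inside (0, 1)
        noise.append(-math.log(-math.log(u)))
    return softmax([(l + g) / tau for l, g in zip(logits, noise)])

logits = [2.0, 0.5, -1.0]           # hypothetical scores for three program tokens
print(softmax(logits))              # "soft" mode: plain softmax
print(gumbel_softmax(logits, 0.5))  # "gumbel" mode: noisy, sharper sample
```

PyTorch ships this operation as `torch.nn.functional.gumbel_softmax(logits, tau, hard)`. The REINFORCE modes instead sample discrete tokens and estimate gradients from the reward.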

Module memory network (PG_memory)


Module memory network end2end (PG_endtoend)

Models

Stack Attention (SAN)
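In the stacked attention paper (arXiv 1511.02274), each attention layer scores every image region against the current query. It uses `h = tanh(W_I v_i + W_Q u + b_A)` and `p = softmax(w_P . h + b_P)`, then forms the refined query `u' = sum_i p_i v_i + u`. Several layers are stacked, starting from the question vector. The plain-Python sketch below follows those equations under simplifying assumptions (tiny dimensions, hand-picked weights, hypothetical function names). It is not the repository's `SAN` module.

```python
import math

def matvec(W, x):
    """Matrix (list of rows) times vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attention_layer(regions, u, W_I, W_Q, b_A, w_P, b_P):
    """One stacked-attention hop.

    regions: m region feature vectors of dim d; u: query vector of dim d.
    W_I, W_Q: k x d matrices; b_A: length-k bias; w_P: length-k vector; b_P: scalar.
    Returns (attention weights over regions, refined query).
    """
    q = [a + b for a, b in zip(matvec(W_Q, u), b_A)]
    scores = []
    for v in regions:
        h = [math.tanh(a + b) for a, b in zip(matvec(W_I, v), q)]
        scores.append(sum(w * hi for w, hi in zip(w_P, h)) + b_P)
    p = softmax(scores)
    attended = [sum(p_i * v[j] for p_i, v in zip(p, regions)) for j in range(len(u))]
    return p, [a + b for a, b in zip(attended, u)]

def stacked_attention(regions, question_vec, layers):
    """Apply one hop per parameter tuple; each hop's output query feeds the next."""
    u = question_vec
    for params in layers:
        _, u = attention_layer(regions, u, *params)
    return u

regions = [[1.0, 0.0], [0.0, 1.0]]  # two image regions, d = 2
identity, zeros = [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]
weights, u1 = attention_layer(regions, [1.0, 0.0], identity, zeros, [0.0, 0.0], [1.0, 0.0], 0.0)
print(weights)  # the first region receives more weight
```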

Stack Attention word2word (SAN_wbw)

Module Network (PG)
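In the module network approach (arXiv 1705.03633), a program generator predicts a sequence of functions, and an execution engine assembles one module per function and runs them. The sketch below assumes programs are linearized in prefix order and shows how such a sequence can be executed with a stack. The modules here are toy symbolic functions over a hand-written scene, all hypothetical. In the actual execution engine, each module is a small neural network operating on image feature maps.

```python
# Hypothetical toy scene and modules; each module maps its inputs to an output.
SCENE = [
    {"color": "red", "shape": "cube"},
    {"color": "blue", "shape": "sphere"},
    {"color": "red", "shape": "sphere"},
]

MODULES = {  # name -> (arity, function)
    "scene": (0, lambda: list(SCENE)),
    "filter_color[red]": (1, lambda objs: [o for o in objs if o["color"] == "red"]),
    "filter_shape[sphere]": (1, lambda objs: [o for o in objs if o["shape"] == "sphere"]),
    "intersect": (2, lambda a, b: [o for o in a if o in b]),
    "count": (1, lambda objs: len(objs)),
}

def execute_prefix(program, modules=MODULES):
    """Run a program in prefix order, e.g. count(filter_color[red](scene))."""
    stack = []
    for name in reversed(program):  # children are evaluated before their parent
        arity, fn = modules[name]
        args = [stack.pop() for _ in range(arity)]  # leftmost child is popped first
        stack.append(fn(*args))
    if len(stack) != 1:
        raise ValueError("malformed program")
    return stack[0]

print(execute_prefix(["count", "filter_shape[sphere]", "filter_color[red]", "scene"]))  # 1
```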

Module-Memory Network (PG_memory)

Module-Memory Network end2end (PG_endtoend)
