lucasmansilla/DGvGS

Domain Generalization via Gradient Surgery

This repository contains the source code for the paper "Domain Generalization via Gradient Surgery" (ICCV 2021). The paper is available at https://arxiv.org/abs/2108.01621.

Instructions

This project uses Python 3.8.10 and PyTorch 1.10.0.

Data:

  1. Download the PACS (Li et al., 2017), VLCS (Fang et al., 2013) and Office-Home (Venkateswara et al., 2017) datasets and place them in data/raw/.
  2. Resize the images and generate the training, validation and test splits by running ./00_prepare_data.sh. Do this only after setting up the project environment (instructions below).
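Before running the preparation script, it can help to confirm that the raw datasets are where step 1 puts them. The following is a minimal sketch, not part of the repository: the helper name `missing_datasets` is hypothetical, and the folder names passed to it are assumptions, since the exact directory names expected under data/raw/ are not stated here.

```python
from pathlib import Path


def missing_datasets(raw_dir, names):
    """Return the dataset folder names that are not present under raw_dir."""
    raw = Path(raw_dir)
    # is_dir() is False for paths that do not exist, so a missing raw_dir
    # simply reports every dataset as missing.
    return [name for name in names if not (raw / name).is_dir()]


# ASSUMPTION: these folder names are illustrative; adjust them to match
# whatever layout 00_prepare_data.sh actually expects.
expected = ["PACS", "VLCS", "OfficeHome"]
print("Missing under data/raw:", missing_datasets("data/raw", expected))
```

An empty list means every expected folder was found; anything else lists the datasets that still need to be downloaded or moved.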

Project environment:

  1. Create and activate a virtual environment: 1) python3 -m venv env, 2) source env/bin/activate
  2. Install the required packages: pip install -r requirements.txt
  3. Install the project modules (src): pip install -e .

Simulations:

To run simulations across all datasets (PACS, VLCS and Office-Home) and methods (Deep-All, Agr-Sum, Agr-Rand and PCGrad), execute ./01_run_trials.sh.

To run a particular combination of dataset and method, use the scripts/train_model.py script. For example, the following command:

python scripts/train_model.py \
    --data_dir=data/processed \
    --results_dir=results/train \
    --dataset=PACS \
    --method=deep-all

will run Deep-All on PACS and save the results in results/train.
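To run a handful of combinations without the full ./01_run_trials.sh sweep, the command above can be generated programmatically. The sketch below is an illustration rather than part of the repository: `build_command` and `build_all` are hypothetical helpers, and only the `PACS` and `deep-all` values are taken from the example above. The accepted spellings for the other datasets and methods are not listed here, so check train_model.py before adding them.

```python
import itertools
import shlex
import sys


def build_command(dataset, method,
                  data_dir="data/processed", results_dir="results/train"):
    """Build the argument list for one train_model.py run."""
    return [
        sys.executable, "scripts/train_model.py",
        f"--data_dir={data_dir}",
        f"--results_dir={results_dir}",
        f"--dataset={dataset}",
        f"--method={method}",
    ]


def build_all(datasets, methods):
    """Build one command per (dataset, method) pair."""
    return [build_command(d, m) for d, m in itertools.product(datasets, methods)]


# Only values confirmed by the example above; extend these lists once the
# exact --dataset and --method spellings are known.
for cmd in build_all(["PACS"], ["deep-all"]):
    print(shlex.join(cmd))
```

This only prints the commands. To execute them from the repository root, pass each list to `subprocess.run(cmd, check=True)`, which stops at the first failing run.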

Reference

  • Mansilla, L., Echeveste, R., Milone, D. H., & Ferrante, E. (2021). Domain generalization via gradient surgery. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 6630-6638).

License

MIT