
MDAN: Multiple Source Domain Adaptation with Adversarial Learning for Alfalfa

How to set up a VM that runs the MDAN demo (as of 3/3/2021). Install Miniconda first if you haven't already:

  1. Create a conda virtual environment with: conda create -n myenv python=3.6.6 scipy numpy (change myenv to mdan or something similar)
  2. Activate the new environment: conda activate myenv
  3. Remove the other Python installation: conda remove python=3.6.8
  4. Install PyTorch: conda install -c pytorch pytorch=1.0.0
  5. Good luck! I hope I got that right; it's at least very close.
  6. Now cd into the myenv directory and run: git clone https://github.com/thejonathanvancetrance/MDAN
  7. cd into myenv/MDAN and run the demo:
python main_amazon.py -o [maxmin|dynamic]
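
Before running the demo, you can sanity-check the environment with a quick import test (a minimal sketch, not part of the repo):

# Verify that the environment matches the prerequisites listed below.
import torch
import numpy
import scipy

print("torch:", torch.__version__)   # expect >= 1.0.0
print("numpy:", numpy.__version__)
print("scipy:", scipy.__version__)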

This project is forked from:

PyTorch demo code for the papers Multiple Source Domain Adaptation with Adversarial Learning and Adversarial Multiple Source Domain Adaptation by Han Zhao, Shanghang Zhang, Guanhang Wu, João Costeira, José Moura, and Geoff Gordon.

Summary

MDAN is a method for domain adaptation with multiple sources. Specifically, during training, $k$ labeled source datasets, one for each of the $k$ source domains, together with one unlabeled target dataset, are used to train the model jointly. A schematic representation of the overall model during the training phase is shown in the following figure:

Essentially, MDAN contains three components:

  • A feature extractor, parametrized by a neural network.
  • A hypothesis (classifier/regressor) for the desired task.
  • A set of $k$ domain classifiers, each a binary classifier that tries to discriminate between the pair $(S_i, T)$.
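
For concreteness, the three components could be wired up roughly like this (a minimal sketch with made-up layer sizes and the placeholder class name MDANSketch; the repo's actual model definitions may differ):

import torch.nn as nn

class MDANSketch(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes, num_domains):
        super().__init__()
        # 1) Shared feature extractor.
        self.features = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # 2) Hypothesis (task classifier) trained on the k labeled sources.
        self.hypothesis = nn.Linear(hidden_dim, num_classes)
        # 3) k domain classifiers, one binary classifier per (S_i, T) pair.
        self.domain_classifiers = nn.ModuleList(
            [nn.Linear(hidden_dim, 2) for _ in range(num_domains)]
        )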

At a high level, in each iteration:

  1. The feature extractor and the hypothesis try to learn an informative representation and a decision boundary that generalize well on the $k$ source tasks (for which we have labels).
  2. The feature extractor and the domain classifiers form a two-player zero-sum game: the feature extractor tries to learn a domain-invariant representation, while each domain classifier tries to distinguish whether a given sample comes from its source domain or the target domain. Note that each domain classifier is only responsible for one specific domain classification task, i.e., for the pair $(S_i, T)$.

Since we have $k$ domain classifiers, we need to define an overall reward for the whole set. To do so, we develop two variants of MDAN:

  • Hard-Max MDAN: The overall reward/error is defined to be the minimum of the $k$ domain classification error rates (left part of the red box in the figure above).
  • Soft-Max MDAN: The overall reward/error is defined to be the $\log\sum\exp(\cdot)$ of all the $k$ domain classification errors (right part of the red box in the figure).
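
As a rough illustration, the two aggregation rules could look like the following (a minimal sketch; aggregate_domain_errors is a placeholder name, and the exact combination with the task loss and the adversarial coefficient lives in main_amazon.py):

import torch

def aggregate_domain_errors(domain_errors, variant="maxmin"):
    """domain_errors: 1-D tensor holding the k per-domain classification errors."""
    if variant == "maxmin":
        # Hard-Max: keep only the minimum of the k domain classification errors.
        return torch.min(domain_errors)
    elif variant == "dynamic":
        # Soft-Max: smooth the hard choice with log-sum-exp over all k errors.
        return torch.logsumexp(domain_errors, dim=0)
    raise ValueError("unknown variant: " + variant)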

A more detailed description of these two variants can be found in Section 4 of the paper Adversarial Multiple Source Domain Adaptation.

Optimization

It is notoriously hard to optimize a minimax problem when it is nonconvex. Our goal is to converge to a saddle point. In this code repo we use the double gradient descent method (e.g., the primal-dual gradient method) to optimize the objective function. Intuitively, this means that we apply simultaneous gradient updates to all components of the model. By comparison, a block coordinate method would fix either the set of $k$ domain classifiers or the feature extractor and hypothesis, optimize the other until convergence, and then iterate from there.

Specifically, we use the well-known gradient reversal layer to implement this method. A PyTorch code snippet is shown below:

import torch

class GradientReversalLayer(torch.autograd.Function):
    """
    Gradient reversal layer for the domain adaptation network.
    The forward pass is the identity function, while the backward pass
    negates the incoming gradient.
    """
    @staticmethod
    def forward(ctx, inputs):
        # Identity map; view_as avoids returning the input tensor object itself.
        return inputs.view_as(inputs)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse the sign of the gradient flowing back into the feature extractor.
        return -grad_output
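
To see how the reversal layer realizes the simultaneous updates described above, here is a minimal usage sketch for a single source-target pair, assuming the MDANSketch module sketched in the Summary section; the dimensions, optimizer choice, and loss bookkeeping are illustrative, not the repo's exact code:

import torch
import torch.nn.functional as F

model = MDANSketch(input_dim=5000, hidden_dim=500, num_classes=2, num_domains=3)
optimizer = torch.optim.Adadelta(model.parameters())

def training_step(src_x, src_y, tgt_x, domain_index):
    # Shared features for a labeled source batch and an unlabeled target batch.
    fs, ft = model.features(src_x), model.features(tgt_x)
    # Task loss on the labeled source batch.
    task_loss = F.cross_entropy(model.hypothesis(fs), src_y)
    # Domain loss through the gradient reversal layer: the domain classifier tries
    # to separate source from target, while the reversed gradient pushes the
    # feature extractor toward domain-invariant features.
    d = model.domain_classifiers[domain_index]
    logits = torch.cat([d(GradientReversalLayer.apply(fs)),
                        d(GradientReversalLayer.apply(ft))])
    labels = torch.cat([torch.zeros(len(src_x), dtype=torch.long),
                        torch.ones(len(tgt_x), dtype=torch.long)])
    domain_loss = F.cross_entropy(logits, labels)
    # A single backward pass updates the extractor, hypothesis, and domain
    # classifier at the same time (simultaneous gradient updates).
    loss = task_loss + domain_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()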

Prerequisites

  • Python 3.6.6
  • PyTorch >= 1.0.0
  • Numpy
  • Scipy

The following explains how to reproduce the Amazon sentiment analysis experiment from the paper.

Training + Evaluation

Run

python main_amazon.py -o [maxmin|dynamic]

Here maxmin corresponds to the Hard-Max variant and dynamic corresponds to the Soft-Max variant.

Several practical suggestions on training these models:

  • The adaptation performance depends on the --mu hyperparameter, which is the coefficient of the domain adversarial loss. This hyperparameter is dataset dependent and should be chosen appropriately for each dataset.
  • One may not want to set the number of training epochs too large. Theoretically, this could hurt the adaptation performance when the label distributions of the source and target domains differ significantly. This phenomenon is observed in practice and further explained in our recent paper On Learning Invariant Representations for Domain Adaptation.
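
For example, one could sweep a few values of the coefficient (assuming --mu is accepted on the command line exactly as named above; the value below is only a placeholder, not a recommended setting):

python main_amazon.py -o dynamic --mu 0.01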

Citation

If you use this code for your research and find it helpful, please cite our paper Multiple Source Domain Adaptation with Adversarial Learning or Adversarial Multiple Source Domain Adaptation:

@inproceedings{zhao2018multiple,
  title={Multiple source domain adaptation with adversarial learning},
  author={Zhao, Han and Zhang, Shanghang and Wu, Guanhang and Moura, Jos{\'e} MF and Costeira, Joao P and Gordon, Geoffrey J},
  booktitle={International Conference on Learning Representations, workshop track},
  year={2018}
}

or

@inproceedings{zhao2018adversarial,
  title={Adversarial multiple source domain adaptation},
  author={Zhao, Han and Zhang, Shanghang and Wu, Guanhang and Moura, Jos{\'e} MF and Costeira, Joao P and Gordon, Geoffrey J},
  booktitle={Advances in Neural Information Processing Systems},
  pages={8568--8579},
  year={2018}
}

Contact

Please email han.zhao@cs.cmu.edu if you have any questions, comments, or suggestions.
