Justin-Tan/invariant_reps

Invariant Representations

Learning invariant representations with mutual information regularization.

Usage

The code depends on TensorFlow 1.13.

# Check command line arguments
$ python3 train.py -h
# Run, e.g.
$ python3 train.py -i /my/training/data -test /my/testing/data --name my_model -lambda 10 -MI -kl

To enable adversarial training, based on a variant of the method proposed in Louppe et al., use adv_train.py in place of train.py and set use_adverary = True in the config file.

Regularization Methods

The method adds a penalty to the objective function that discourages statistical dependence between the sensitive variables and an intermediate representation of the data, obtained by passing the data through a neural network. The penalty measures a divergence between the joint distribution of the representation and the sensitive variables and the product of their marginals. See the presentation below for further details. Several penalties are implemented; the recommended one, found to be the most effective, is the kl_update penalty, which is analogous to the generator update rule proposed in arXiv:1610.04490. You can also enable adversarial training, as proposed in arXiv:1611.01046, to attempt decorrelation.
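
As a rough illustration of the idea (not the estimator used in this repository), the penalized objective can be thought of as the task loss plus lambda times an estimate of the mutual information between the representation and the sensitive variable. The sketch below uses a simple histogram (plug-in) estimate on a scalar representation in plain Python. The function names, the binning scheme, the default lam=10.0, and the toy data are all assumptions made for the example.

```python
import math
import random
from collections import Counter


def discretize(values, n_bins=8):
    """Map real values to integer bin indices over their observed range."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / n_bins or 1.0  # fall back to 1.0 for constant input
    return [min(int((v - lo) / width), n_bins - 1) for v in values]


def mutual_information(z_bins, s_bins):
    """Plug-in estimate of I(Z; S) in nats from paired discrete samples."""
    n = len(z_bins)
    joint = Counter(zip(z_bins, s_bins))
    pz = Counter(z_bins)
    ps = Counter(s_bins)
    mi = 0.0
    for (z, s), count in joint.items():
        p_joint = count / n
        # KL divergence between the joint and the product of the marginals
        mi += p_joint * math.log(p_joint * n * n / (pz[z] * ps[s]))
    return mi


def penalized_loss(task_loss, z, s, lam=10.0, n_bins=8):
    """Hypothetical objective: task loss plus lam times the MI estimate."""
    mi = mutual_information(discretize(z, n_bins), discretize(s, n_bins))
    return task_loss + lam * mi


# Toy data: one representation that leaks s, one that is independent of it.
rng = random.Random(0)
s = [rng.gauss(0.0, 1.0) for _ in range(5000)]
z_dependent = [v + 0.1 * rng.gauss(0.0, 1.0) for v in s]
z_independent = [rng.gauss(0.0, 1.0) for _ in range(5000)]

mi_dep = mutual_information(discretize(z_dependent), discretize(s))
mi_ind = mutual_information(discretize(z_independent), discretize(s))
```

With the same task loss, the representation that leaks the sensitive variable receives the larger penalty, which is the pressure the regularizer is meant to apply during training. A histogram estimate like this is not differentiable, so it only illustrates the quantity being penalized.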

Extensions

The network architecture is kept modular from the rest of the computational graph. For ease of experimentation, the codebase supports any architecture that yields logits for binary classification. The adversarial training procedure can likewise interface with any network architecture. To swap in your own network, create a @staticmethod under the Network class in network.py:

@staticmethod
def my_network(x, config, **kwargs):
    """
    Inputs:
    x: example data
    config: class defining hyperparameter values

    Returns:
    network logits
    """

    # To prevent overfitting, we don't even look at the inputs!
    # tf.shape(x)[0] also works when the batch dimension is unknown at graph-build time
    return tf.random_normal([tf.shape(x)[0], config.n_classes], seed=42)

Now open model.py and edit one of the first lines under the Model init:

class Model():
    def __init__(self, **kwargs):

        arch = Network.my_network
        # The rest of computational graph defined here
        # (You shouldn't need to do anything else)
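
The swap above amounts to passing a function around: anything that accepts (x, config, **kwargs) and returns logits can be plugged in. Below is a framework-free sketch of that contract, using plain Python lists in place of tensors. ToyConfig, the two toy architectures, the arch constructor argument, and build_logits are hypothetical names for this example and do not exist in the repository.

```python
class ToyConfig:
    """Hypothetical stand-in for the config class."""
    n_classes = 2


class Network:
    @staticmethod
    def zeros(x, config, **kwargs):
        # One row of zero logits per example
        return [[0.0] * config.n_classes for _ in x]

    @staticmethod
    def row_sum(x, config, **kwargs):
        # Toy "architecture": every logit is the sum of the example's features
        return [[sum(row)] * config.n_classes for row in x]


class Model:
    def __init__(self, arch=Network.zeros, **kwargs):
        # The rest of the graph only ever calls self.arch(x, config)
        self.arch = arch

    def build_logits(self, x, config):
        return self.arch(x, config)


batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
logits = Model(arch=Network.row_sum).build_logits(batch, ToyConfig)
```

Here the architecture is passed as a constructor argument, not edited in place, purely to keep the sketch short. The point is that the caller depends only on the function signature and the shape of the returned logits.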

Monitoring / checkpoints

TensorBoard summaries are written periodically to tensorboard/, and checkpoints are saved to checkpoints/ every epoch.

Dependencies

Resources / Related Work

Future Work

  • Increase stability of MI estimator.
  • Include HEP example.

Contact

Feel free to open an issue or ping [firstname.lastname@coepp.org.au] for any questions.
