
Making the training process generic for custom models #8

Open
shivamsaboo17 opened this issue Aug 30, 2019 · 19 comments

@shivamsaboo17

Hi! Great paper! I implemented manifold mixup and also support for interpolated adversarial training (https://github.com/shivamsaboo17/ManifoldMixup) for any custom model defined by the user, using PyTorch's forward hook functionality:

  1. Select a random index and apply a forward hook to that layer
  2. Do a forward pass with input x_0 and record the output at the hooked layer
  3. Use this output along with a new input x_1, by adding a new hook at the same layer, to do the mixup operation

For now I am selecting the layer randomly without considering the type of layer (batchnorm, relu, etc. are counted as separate layers), so I wanted to know whether there should be a layer selection rule, such as "mixup should be done only after a conv block in a resnet", and if so, how to extend such a rule to custom models that users might build.
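
For reference, here is a minimal sketch of that hook-based procedure (the helper name manifold_mixup_forward and its details, such as the leaf-module selection and the Beta-sampled coefficient, are illustrative rather than the actual API of the linked repo):

    import random
    import torch

    def manifold_mixup_forward(model, x0, x1, alpha=2.0):
        """Sketch only: mix the hidden states of x0 and x1 at a randomly chosen layer."""
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
        # step 1: pick a random leaf module (batchnorm, relu, etc. each count separately)
        layers = [m for m in model.modules() if len(list(m.children())) == 0]
        layer = random.choice(layers)

        cache = {}

        def fetch(module, inputs, output):
            # step 2: record the activation of x0 at the hooked layer
            cache["h0"] = output

        handle = layer.register_forward_hook(fetch)
        model(x0)
        handle.remove()

        def mix(module, inputs, output):
            # step 3: returning a value from a forward hook replaces the module's output,
            # so x1's activation is swapped for the mixed activation from here onward
            return lam * output + (1.0 - lam) * cache["h0"]

        handle = layer.register_forward_hook(mix)
        out = model(x1)
        handle.remove()
        return out, lam

The loss would then be mixed with the same coefficient, e.g. lam * criterion(out, y1) + (1 - lam) * criterion(out, y0).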

@alexmlamb
Collaborator

What we always did was do the mixing directly following a block of residual layers. We actually just did this for convenience and it has no strong motivation.

One thing to consider is that if you do the mixing "within" a residual layer, it could have some odd properties, for example:

h[t+1] = relu( h[t] + mixing(W1*relu(W2*h[t])) )

I think that might be a bad idea, since you're only mixing one of the residual layer's updates, but it's not something I've thought about very deeply.
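
To make this concrete, here is a small illustrative sketch (not the paper's code) contrasting the two placements for a two-layer residual update; mixing inside the branch leaves the skip path h unmixed, while mixing after the block interpolates the whole hidden state:

    import torch.nn.functional as F

    def mix(a, b, lam):
        return lam * a + (1 - lam) * b

    # mixing "within" the residual layer: the skip path h is not mixed
    def step_mix_inside(h, h_other, W1, W2, lam):
        branch = W1 @ F.relu(W2 @ h)
        branch_other = W1 @ F.relu(W2 @ h_other)
        return F.relu(h + mix(branch, branch_other, lam))

    # mixing after the full residual layer: the entire hidden state is mixed
    def step_mix_after(h, h_other, W1, W2, lam):
        out = F.relu(h + W1 @ F.relu(W2 @ h))
        out_other = F.relu(h_other + W1 @ F.relu(W2 @ h_other))
        return mix(out, out_other, lam)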

I think it's worth studying though. There are other cases like networks that have multiple branches or recurrent neural networks, where "picking a layer to mix" is more ambiguous, and I think there's value in understanding it better.

@shivamsaboo17
Author

I get your point. It would be interesting to somehow identify the layer for each input batch, perhaps based on sensitivity to the loss or something like that, before actually mixing.
For now, I will think about a way to identify user-defined blocks and add an option where you simply provide valid indices as input. Let me know if you have any suggestions!

@alexmlamb
Collaborator

So another practical thing to consider is that we usually set the choice of layers as {'input', 'output of first resblock', 'output of second resblock'}. It's also notable that you don't really want to mix too close to the output, although it doesn't hurt that much (our paper has an ablation study on this in a table).

I'm thinking more about how your code works and I'll try to play around with it. One thing I'm curious about is how it behaves when a module is defined once and called multiple times in the forward pass. For example, if I say self.relu = nn.ReLU() and then call it after every layer in the forward pass, PyTorch still just sees it as one module, right?

Your overall implementation is interesting because it looks like you run two batches instead of running one batch and mixing within the same batch. The two approaches aren't exactly the same, but of course they're pretty similar.

One related situation I'm interested in is training with manifold mixup with a batch size of 1 (this is common when you have very large examples, for example really high-resolution images), and I think it could be approached using your code base.
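
If it helps, here is a hypothetical training step for that batch-size-1 case, reusing the two-pass manifold_mixup_forward helper sketched in the first comment (x0/y0 and x1/y1 are two separate single-example batches; all names are assumptions). Because the mixing partner comes from a second forward pass rather than from shuffling within the batch, a batch of one example still has something to mix with:

    import torch.nn.functional as F

    out, lam = manifold_mixup_forward(model, x0, x1, alpha=2.0)
    loss = lam * F.cross_entropy(out, y1) + (1 - lam) * F.cross_entropy(out, y0)
    loss.backward()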

@shivamsaboo17
Author

> So another practical thing to consider is that we usually set the choice of layers as {'input', 'output of first resblock', 'output of second resblock'}. It's also notable that you don't really want to mix too close to the output, although it doesn't hurt that much (our paper has an ablation study on this in a table).

That's interesting. Regarding mixing too close to the output, I will look at the ablation study in more detail. I am curious why this happens, though, because in the extreme case (at the logit level) it is equivalent to a linear combination of probabilities, which is perhaps the best approximation of what our output should be? (Correct me if I am wrong.)

> I'm thinking more about how your code works and I'll try to play around with it. One thing I'm curious about is how it behaves when a module is defined once and called multiple times in the forward pass. For example, if I say self.relu = nn.ReLU() and then call it after every layer in the forward pass, PyTorch still just sees it as one module, right?

Great catch! I hadn't considered this. I think fetch_hook will just keep overriding the result up to the last instance of such a module, but when modify_hook is called it should throw an error if the shapes don't match (i.e. if it combines the intermediate output from the last instance of the module with the first instance for the new input). I will try to replicate this scenario.

@alexmlamb
Collaborator

So I thought about it a bit more and I have a few thoughts:

  1. You'll want to pick the layers to mix at prior to runtime, because you need to set the probability of mixing at each layer.

  2. Maybe the right structure would be to have a "Mixup wrapper" that gets passed into the model's init, and that the user needs to manually wrap around each module where mixup will be performed on that module's output.

  3. For example, in this very standard PyTorch ResNet definition, I was thinking that the original:

https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py#L72

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

... becomes...

    def __init__(self, block, num_blocks, MixWrapper, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = MixWrapper(self._make_layer(block, 64, num_blocks[0], stride=1))
        self.layer2 = MixWrapper(self._make_layer(block, 128, num_blocks[1], stride=2))
        self.layer3 = MixWrapper(self._make_layer(block, 256, num_blocks[2], stride=2))
        self.layer4 = MixWrapper(self._make_layer(block, 512, num_blocks[3], stride=2))
        self.linear = nn.Linear(512*block.expansion, num_classes)
  4. This MixWrapper would make a new module which you could then apply the hook to. The MixWrapper would just call whatever its input is with whatever arguments, then do the mixing, and then return that output.

  5. I think this would be alright, although it does have the downside that it would no longer be fully model-agnostic, but I don't see any nice way around that.
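
A minimal sketch of what such a wrapper could look like (the class name MixWrapper, its attributes, and the training-step helper are illustrative, not an existing API): it calls the wrapped module with whatever arguments it receives and, when mixing is active, interpolates its output within the batch; the training loop picks one wrapped module per batch and mixes the targets with the same coefficient.

    import random
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MixWrapper(nn.Module):
        """Wraps a module so manifold mixup can be applied to its output (sketch)."""
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.lam = None     # mixing coefficient; None means plain pass-through
            self.index = None   # permutation of the batch to mix with

        def forward(self, *args, **kwargs):
            out = self.module(*args, **kwargs)
            if self.lam is not None and self.index is not None:
                out = self.lam * out + (1 - self.lam) * out[self.index]
            return out

    def mixup_step(model, x, y, criterion=F.cross_entropy, alpha=2.0):
        """Illustrative training step: mix at one randomly chosen wrapped module."""
        wrappers = [m for m in model.modules() if isinstance(m, MixWrapper)]
        chosen = random.choice(wrappers)
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
        index = torch.randperm(x.size(0), device=x.device)
        chosen.lam, chosen.index = lam, index
        out = model(x)
        chosen.lam = chosen.index = None
        return lam * criterion(out, y) + (1 - lam) * criterion(out, y[index])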

@shivamsaboo17
Author

Sounds like a good idea. I will implement this wrapper and let you know!
Meanwhile, I still have the question about mixing too close to the output: in the extreme case (at the logit level) it is equivalent to a linear combination of probabilities, which is perhaps the best approximation of what our output should be?

@shivamsaboo17
Author

I also don't think this solves the problem of a module being defined once and used multiple times in the forward pass. For now, I am thinking about raising a warning and using only the first instance for adding the hooks.
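
One hypothetical way to implement that (a sketch, not the repo's actual code) is to wrap the hook function so it counts calls per forward pass, acts only on the first call, and warns when the same module fires again:

    import warnings

    class FirstCallOnlyHook:
        """Forward hook wrapper that acts only on a module's first call per
        forward pass and warns about further calls (illustrative sketch)."""
        def __init__(self, fn):
            self.fn = fn
            self.calls = 0

        def reset(self):
            # call this before every forward pass
            self.calls = 0

        def __call__(self, module, inputs, output):
            self.calls += 1
            if self.calls > 1:
                warnings.warn(f"{module.__class__.__name__} was called {self.calls} "
                              "times in one forward pass; only the first call is hooked.")
                return None  # leave later calls untouched
            return self.fn(module, inputs, output)

Registering it would look like hook = FirstCallOnlyHook(mix_fn); layer.register_forward_hook(hook), with hook.reset() called before each batch (mix_fn standing in for whichever hook does the mixing).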

@alexmlamb
Collaborator

> Meanwhile, I still have the question about mixing too close to the output: in the extreme case (at the logit level) it is equivalent to a linear combination of probabilities, which is perhaps the best approximation of what our output should be?

Right, so I was thinking of mixing right before the final softmax, in which case it would be impossible for it to be linear, due to the presence of the softmax. If you mixed after the softmax, I think it wouldn't do anything, but I also haven't checked the math.
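
A quick numerical check of this point (just a sketch, not from the paper): mixing logits and then applying the softmax is not the same as mixing the resulting probabilities, whereas mixing after the softmax is by construction exactly the linear combination of probabilities.

    import torch
    import torch.nn.functional as F

    a = torch.tensor([2.0, 0.0, -1.0])   # logits of one example
    b = torch.tensor([-1.0, 1.0, 0.0])   # logits of another example
    lam = 0.5

    mix_then_softmax = F.softmax(lam * a + (1 - lam) * b, dim=0)
    softmax_then_mix = lam * F.softmax(a, dim=0) + (1 - lam) * F.softmax(b, dim=0)

    # False: mixing before the softmax is not linear in the probabilities
    print(torch.allclose(mix_then_softmax, softmax_then_mix))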

> I also don't think this solves the problem of a module being defined once and used multiple times in the forward pass. For now, I am thinking about raising a warning and using only the first instance for adding the hooks.

This makes a lot of sense to me. I think if you let the user control where to mix and then warn them if the mixing thing gets called twice, that seems like a pretty good solution.

@shivamsaboo17
Author

shivamsaboo17 commented Sep 4, 2019

Hi. Sorry for the delay, but I made the changes we discussed (described in my readme). I kept mixup_all=True/False as an argument in case someone does not have access to the source code of the model (just the weights); in that case we can allow using all modules for the mixup.
Let me know if you find any inconsistencies, and thanks for the suggestions!

@alexmlamb
Collaborator

Looks good to me, although maybe there's still a more elegant way to do it.

I think it will work on the models defined here:

https://github.com/kuangliu/pytorch-cifar

And it's also worth noting that someone could add the mixup modules outside of the model's init function, for example by saying:

    model.layer1 = MixupModule(model.layer1)
    model.layer2 = MixupModule(model.layer2)


If it's tested reasonably well, maybe it's possible to try to add it into PyTorch?

@shivamsaboo17
Author

shivamsaboo17 commented Sep 4, 2019

Adding it to PyTorch sounds like a great idea! First I will test my code on the CIFAR-10 models you mention above (it might take some time as I have exams coming up) and will ping here with the results.

@alexmlamb
Collaborator

Okay cool. The only thing is that Manifold Mixup sometimes requires more epochs to get the best improvement.

When doing this it might also be worth trying your "mix in all modules" variant to see how it does. Could be interesting.

@alexmlamb
Collaborator

Any progress on this? I'd be happy to keep working on getting this working :)

@shivamsaboo17
Author

Hi Alex. Unfortunately I wasn't able to make much progress due to exams. However, I will run the experiments and report the results here as soon as I can :)

@nestordemeure

nestordemeure commented Feb 2, 2020

Here is a fastai-compatible implementation inspired by @shivamsaboo17's implementation, with a slightly different API (I give three different ways to select modules) and improved performance (as I avoid running the model twice per batch).

It makes using manifold mixup with fastai trivial (and as performant as using input mixup), which might help democratize the technique :)

@alexmlamb
Collaborator

Awesome! I'm looking forward to trying it out after the ICML deadline.

It would also be good to maybe try to reproduce Manifold Mixup's CIFAR-10 results (or at least show an improvement over input mixup).

@alexmlamb
Collaborator

One more thing: for Manifold Mixup, I think it's important to mix after a full layer or after a full resblock. If the model has a skip connection but you don't mix the values passed through the skip connection, I think it will mess things up.

I'm not sure if there's any elegant way to do this automatically, aside from perhaps just making people call a manifold mixup callback inside their model code.

@christopher-beckham
Collaborator

christopher-beckham commented Feb 5, 2020 via email

@nestordemeure

nestordemeure commented Feb 5, 2020

In practice, we have access to the list of all modules (in order), so it is easy to say that some module types should not be used, or to exclude the first module from the list (to skip at least one layer).

The problem with skip connections might explain the disappointing results I had with a resnet when I sampled from all layers. Sampling only the outputs of the ResBlocks solved the problem and finally produced an improvement over input mixup.

I currently implement the following rule:

  • if the user passes a list of modules, use it and nothing else
  • else, if there are ManifoldMixup modules, use only those (the user can use them to indicate places suitable for mixup)
  • else, if there are ResBlocks, use only those (similarly to the mixup paper)
  • else, use all non-recurrent layers

Are there other criteria I might be missing?
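
For what it's worth, here is a rough sketch of that fallback rule in plain PyTorch (the marker class ManifoldMixupModule and the name-based ResBlock check are assumptions standing in for the actual fastai implementation):

    import torch.nn as nn

    class ManifoldMixupModule(nn.Module):
        """Marker wrapper a user can place around modules suitable for mixup (illustrative)."""
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    def select_mixup_modules(model, user_modules=None):
        """Sketch of the selection rule described above."""
        if user_modules:                                   # rule 1: explicit user list wins
            return list(user_modules)
        modules = list(model.modules())
        marked = [m for m in modules if isinstance(m, ManifoldMixupModule)]
        if marked:                                         # rule 2: user-marked modules
            return marked
        resblocks = [m for m in modules
                     if "block" in m.__class__.__name__.lower()]   # heuristic ResBlock check
        if resblocks:                                      # rule 3: residual blocks only
            return resblocks
        return [m for m in modules                         # rule 4: non-recurrent leaf layers
                if len(list(m.children())) == 0 and not isinstance(m, nn.RNNBase)]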

By the way, applying the mixup on the last layer only gave me even better results (which, for classification, makes sense to me but was still interesting to observe).
