Training state of ResNet coupled with mutable batch_stats collection #9

Open
cgarciae opened this issue Aug 31, 2022 · 1 comment

Comments

@cgarciae

cgarciae commented Aug 31, 2022

Hey @n2cholas!

This is not an immediate issue, but I was playing around with jax_resnet and noticed that ConvBlock decides whether to update its batch statistics based on whether the batch_stats collection is mutable. This initially sounds like a safe bet, but if you embed ResNet inside another module that also happens to use BatchNorm, and you want to train the other module while freezing ResNet, it is not clear how you would do this: making batch_stats mutable so the outer module can update its statistics also switches ResNet's BatchNorm layers into training mode.

mutable = self.is_mutable_collection('batch_stats')
x = self.norm_cls(use_running_average=not mutable, scale_init=scale_init)(x)

To solve this you have to:

  • Accept a use_running_average (or equivalent) argument in ConvBlock.__call__ and pass it to norm_cls.
  • Refactor ResNet into a custom Module (instead of a Sequential) so that it also accepts this argument in __call__ and passes it on to the submodules that expect it.

Some repos use a single train flag to determine the state of both BatchNorm and Dropout.

Anyway, not an immediate issue for me, but it might help some users in the future. Happy to send a PR if the changes make sense.

@n2cholas
Owner

n2cholas commented Sep 3, 2022

Thanks for raising this @cgarciae, definitely a relevant use case. I would prefer having a use_running_average member variable in ConvBlock. Perhaps in the future we can add a use_running_average=None argument in ConvBlock.__call__ if there is sufficient demand, then use nn.merge_param just like Flax does, but my general preference is to configure the behaviour of the module during construction (with @nn.compact you do both at once anyway).

Would be amazing if you could open a PR. Let me know if you have any issues with the environment/tests.
