Information bottlenecks without warnings. #83

Open
Danfoa opened this issue Sep 17, 2023 · 0 comments

Comments

@Danfoa
Contributor

Danfoa commented Sep 17, 2023

Hi @Gabri95,

I want to raise this issue; I am fairly certain it is a real problem, but please correct me if I am mistaken.

By Schur's lemma, and from what I can see in how the basis managers are coded, any equivariant mapping from an input irrep to a different output irrep is necessarily the zero map. This creates information bottlenecks when using the nn.Linear module with an input field type containing irreps that are not present in the output field type (or the other way around).
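
To make this concrete, here is a minimal sketch (assuming escnn's SO(2) group with no base space, analogous to the SO2MLP example below; bias=False is used so the bias on the trivial field does not mask the effect) showing that a Linear layer mapping a frequency-1 input to frequency-0 and frequency-2 output fields produces identically zero features for those fields:

    import torch
    from escnn import group, gspaces, nn

    G = group.so2_group()
    gspace = gspaces.no_base_space(G)

    in_type = gspace.type(G.irrep(1))                           # input only contains the frequency-1 irrep
    out_type = gspace.type(G.irrep(0), G.irrep(1), G.irrep(2))  # output contains frequencies 0, 1 and 2

    # bias=False so the bias on the trivial (frequency-0) field does not hide the zero map
    lin = nn.Linear(in_type, out_type, bias=False)

    x = in_type(torch.randn(8, in_type.size))
    y = lin(x).tensor

    print(torch.allclose(y[:, :1], torch.zeros_like(y[:, :1])))  # frequency-0 output: always zero
    print(torch.allclose(y[:, 3:], torch.zeros_like(y[:, 3:])))  # frequency-2 output: always zero
    print(y[:, 1:3].abs().sum() > 0)                             # frequency-1 output: generally non-zero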

What I find troubling is that no warning whatsoever is raised about these information bottlenecks, which, as far as I can tell, are present in the MLP example of the repo. Let me explain the problem with your SO2MLP example:

In your SO2MLP example, you set the input field type to the "standard" representation of SO(2), which is exactly the frequency-1 irrep. This means the input signal lives only in irrep (1,).

        # the input contains the coordinates of a point in the 3D space
        self.in_type = self.gspace.type(self.G.standard_representation())

Then, for the next layers, you define output representations band-limited to increasingly higher frequencies. The first layer is band-limited to L=1 (irreps (0,) and (1,)), so the mapping between the input irrep (1,) and the output irrep (0,) will always be zero.

        # Layer 1
        # We will use the regular representation of SO(2) acting on signals over SO(2) itself, bandlimited to frequency 1
        # Most of the comments on the previous SO(3) network apply here as well
       
        activation1 = nn.FourierELU(
            self.gspace,
            channels=3, # specify the number of signals in the output features
            irreps=self.G.bl_regular_representation(L=1).irreps, # include all frequencies up to L=1
            inplace=True,
            # the following kwargs are used to build a discretization of the circle containing 6 equally distributed points
            type='regular', N=6,   
        )
        
        # map with an equivariant Linear layer to the input expected by the activation function, apply batchnorm and finally the activation
        self.block1 = nn.SequentialModule(
            nn.Linear(self.in_type, activation1.in_type),
            nn.IIDBatchNorm1d(activation1.in_type),
            activation1,
        )

The second layer is defined with an output representation band-limited to L=3. Since the irreps (2,) and (3,) are not present in the output representation of the first layer, these fields will always be set to zero by the second nn.Linear layer: another information bottleneck.

These bottlenecks set values to zero, as they must to respect equivariance, but the user can easily overlook them if no warning is provided, mainly because the non-linearities and bias terms then modify the zero values coming out of the nn.Linear module, effectively creating fictitious inputs (signals with no relation to the input of the MLP) for the following layers.

I think it would be valuable for the basis managers to raise a warning about the information bottlenecks present in the architecture design.
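
For example (this is just a sketch of what such a check might look like, not the library's actual API), the basis managers could compare the irreps appearing in the input and output field types and warn about the ones that can only be mapped to or from zero:

    from warnings import warn

    def warn_on_irrep_bottlenecks(in_type, out_type):
        # collect the irrep ids appearing in each field type
        in_irreps = {irr for rep in in_type.representations for irr in rep.irreps}
        out_irreps = {irr for rep in out_type.representations for irr in rep.irreps}

        dead_out = out_irreps - in_irreps  # output irreps with no matching input irrep: always zero
        dead_in = in_irreps - out_irreps   # input irreps with no matching output irrep: information discarded

        if dead_out:
            warn(f"Output irreps {sorted(dead_out)} receive no contribution from the input: "
                 f"the corresponding fields will be identically zero.")
        if dead_in:
            warn(f"Input irreps {sorted(dead_in)} do not appear in the output: "
                 f"the information they carry is discarded.")

Calling something like warn_on_irrep_bottlenecks(self.in_type, activation1.in_type) when building block1 would have flagged both bottlenecks described above.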

For the SO(2) example, we should most likely take the 2D input points/vectors and compute their components in the band-limited irreps of SO(2) before passing them to the network. One could argue that a point at (r, θ) can be thought of as a "delta function" at angle θ with amplitude r. The Fourier transform of a delta function is constant, implying that the point is equally present in all irreps. I would then "copy" the input point and assign it to all the band-limited irreps we choose to model SO(2) with.
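
Here is a minimal sketch of that preprocessing in plain torch (the normalization of the Fourier coefficients is glossed over, and the function name is just for illustration):

    import torch

    def lift_point_to_so2_irreps(xy: torch.Tensor, max_freq: int) -> torch.Tensor:
        # interpret the 2D point as a delta of amplitude r at angle theta
        r = xy.norm(dim=-1, keepdim=True)
        theta = torch.atan2(xy[..., 1:2], xy[..., 0:1])

        fields = [r]  # frequency-0 (invariant) component
        for k in range(1, max_freq + 1):
            # frequency-k Fourier coefficients of the delta: (r cos(k theta), r sin(k theta))
            fields.append(r * torch.cos(k * theta))
            fields.append(r * torch.sin(k * theta))
        return torch.cat(fields, dim=-1)

    # e.g. with max_freq=3 a point (x, y) becomes the 7-dimensional feature
    # [r, r cos(theta), r sin(theta), r cos(2 theta), r sin(2 theta), r cos(3 theta), r sin(3 theta)],
    # where the frequency-1 block is just (x, y) itself.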

PS: I double-checked that the outputs of the layers were, in fact, being set to zero, as expected from Schur's lemma.

PPS: Since these bottlenecks are not tracked, the number of trainable parameters of the final architecture will be grossly overestimated. For instance, if half of the irreps in my first layer's output sit in one of these bottlenecks, approximately half of the architecture's parameters won't actually be learning anything meaningful.
