Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize the Fourier transform API #86

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

kalekundert
Copy link
Contributor

This PR is a proposal to refactor the Fourier transform API, with the goal of making it easier to incorporate Fourier transforms into other modules. Here are the two specific use-cases I was trying to facilitate:

  • A Fourier max pooling layer, as discussed in Question about pooling with Fourier representations in R^3 #65. This would be very similar to the existing FourierPointwise class, except after the nonlinearity, there would also be max-pooling and Gaussian blurring steps.

  • An IFT as an output layer. It's probably not clear what I mean by that, and it's possible that I only needed such a layer because I overlooked some easier way of doing things, so I want to take some time to explain the problem I was trying to solve. My goal was to reimplement [Doersch2016], but in 3D and with equivariance. The idea in [Doersch2016] is to create a self-supervised training protocol by taking two nearby crops of an image, and having the model predict the location of the second relative to the first. There would only be a handful of possible relative locations, e.g. above, below, right, and left (for 2D images). I implemented this by having the final layer of my model be a single spectral regular representation (of the quotient space $S^2 = SO(3) / SO(2)$, because the two crops cannot rotate relative to each other), then performing an IFT with each grid point corresponding to one of the possible relative locations. This results in values for each location that can be interpreted as logits. And if the input rotates, so do the logits. To bring this back to the PR at hand, the important point is that this application requires being able to perform an IFT without a subsequent FT.

I think that the best way to support these two use-cases, and possibly others that I haven't thought of, is to create separate FT and IFT modules. That's what the proposed API does. Here are the specific classes involved:

  • InverseFourierTransform: A pytorch module where the input is a tensor with a spectral regular representation, and the output is a tensor of signal values sampled on a grid.

  • FourierTransform: The opposite of InverseFourierTransform. This module also provides the option to prepare the FT matrix with more irreps than will ultimately be output.

  • FourierFieldType: Most equivariant modules accept input/output field types as arguments, but FourierPointwise is an exception. It accepts gspace, channels, and irreps arguments, and uses them to create a compatible field type under the hood. This API is a bit awkward to begin with, but it's worse when the same arguments need to be passed to two different modules.

    To bring the Fourier API in line with all the other modules, I created FourierFieldType. This is a subclass of FieldType that only allows spectral regular representations (possibly with respect to a quotient space). The IFT and FT modules require this field type (and check for it). Other modules are agnostic to it.

  • GridTensor: A class that wraps the output of an IFT and the input to an FT. It's similar in concept to GeometricTensor, except that instead of keeping track of the representation associated with a tensor, it keeps track of the grid. This lets the FT module check that it's compatible with the input it receives, and (for GNNs) restore the coords attribute.

Using these classes, I reimplemented the FourierPointwise class in a way that I believe to be 100% backwards-compatible. The new implementation also removes hundreds of lines of code that were duplicated between FourierPointwise and QuotientFourierPointwise. Below is a simplified FourierRelu version of this class, just to give a sense for how it works:

class FourierRelu(EquivariantModule):
    
    def __init__(self, in_type: FourierFieldType, grid: List[GroupElement]):
        super().__init__()
        self.in_type = self.out_type = in_type
        self.ift = InverseFourierTransform(self.in_type, grid)
        self.ft = FourierTransform(grid, self.out_type)
        
    def forward(self, x_hat: GeometricTensor) -> GeometricTensor:
        assert x_hat.type == self.in_type
        x: GridTensor = self.ift(x_hat)
        F.relu_(x.tensor)
        return self.ft(x)

Minor comments:

  • This PR isn't ready to be merged yet. I haven't updated the documentation, and although all the existing tests pass, I want to write some new tests as well. But before I spend a lot of time on those tasks, I want to know if there's any interest in merging this.

  • I haven't implemented the aforementioned Fourier max pooling module yet. But if there's interest, I could add that to the PR as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant