Add a transformer neural operator model and an accompanying example for training it on Darcy flow #293

Open · wants to merge 23 commits into main

Conversation

zijieli-Jlee (Contributor):

Main changes:

  • Add transformer_no.py under models, which implements a Transformer encoder-decoder architecture for operator learning; __init__.py and model_dispatcher.py are updated accordingly. A rough sketch of the overall pattern is given after this list.
  • Add Gaussian Fourier features and Siren to embeddings.py; these are useful for building the neural field (query-point embedding) in the Transformer decoder.
  • Add a training example train_darcy_transformer.py under scripts; the corresponding config YAML file is darcy_transformer_config.yaml.
  • Fix a bug in the previous attention layer test (#290).
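A rough sketch of the encoder-decoder pattern the model follows (illustrative only; the class and argument names below are not the actual transformer_no.py code): the input function values at the source points are encoded with self-attention, and the output is decoded at arbitrary query coordinates via cross-attention against a query-point embedding.

```python
import torch
import torch.nn as nn

class ToyTransformerNO(nn.Module):
    """Illustrative sketch, not the PR's transformer_no.py: self-attention over
    source points, then cross-attention from query-point embeddings."""

    def __init__(self, in_channels=1, out_channels=1, hidden=64, n_heads=4, n_layers=2):
        super().__init__()
        self.lift = nn.Linear(in_channels + 2, hidden)  # +2 for 2D coordinates
        self.encoder = nn.ModuleList(
            [nn.TransformerEncoderLayer(hidden, n_heads, batch_first=True)
             for _ in range(n_layers)]
        )
        self.query_embed = nn.Linear(2, hidden)  # stand-in for the neural-field query embedding
        self.cross_attn = nn.MultiheadAttention(hidden, n_heads, batch_first=True)
        self.project = nn.Linear(hidden, out_channels)

    def forward(self, u, pos_src, pos_qry=None):
        # u: [batch, n_src, in_channels]; pos_src / pos_qry: [batch, n_pts, 2]
        pos_qry = pos_src if pos_qry is None else pos_qry
        z = self.lift(torch.cat([u, pos_src], dim=-1))
        for layer in self.encoder:
            z = layer(z)
        q = self.query_embed(pos_qry)
        out, _ = self.cross_attn(q, z, z)
        return self.project(out)  # [batch, n_qry, out_channels]
```

Because the decoder is queried at pos_qry, the output function can be evaluated at points that differ from the input grid.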

@dhpitt (Collaborator) left a comment:

Thanks @zijieli-Jlee! This looks awesome. A few minor things:

  1. Thanks for fixing "AttentionKernelLayer fails test due to error in Einsum" (#290)! Would it be OK to split this PR into one that fixes that issue (which we can merge immediately) and one that implements the TransformerNO (which may take longer to go over)?
  2. It would be great to add unit tests for the TransformerNO model, along the lines of neuralop/models/tests/test_fno.py, just verifying that the forward pass produces outputs of the shape we would expect (see the sketch after this list).
  3. It might help keep the code flexible and easier to maintain if the EncoderBlocks were their own module in neuralop.layers, along the lines of the FNOBlocks.
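To illustrate point 2, such a test might look roughly like this (a sketch only; the constructor arguments and forward signature are assumptions, not the model's actual API):

```python
import torch

from neuralop.models import TransformerNO  # assuming the model is exported here


def test_transformer_no_forward_shapes():
    # Hypothetical constructor arguments; adjust to the real signature.
    model = TransformerNO(in_channels=1, out_channels=1, hidden_channels=32, n_dim=2)
    batch_size, n_points = 4, 64
    u = torch.randn(batch_size, n_points, 1)    # input function values
    pos = torch.rand(batch_size, n_points, 2)   # grid point coordinates
    out = model(u, pos=pos)
    # The forward pass should preserve the batch and point dimensions.
    assert out.shape == (batch_size, n_points, 1)
```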

@dhpitt (Collaborator) left a comment:

Thank you for the updates, @zijieli-Jlee; the PR is in really good shape overall. I left a few small comments.

model = get_model(config)


class ModelWrapper(torch.nn.Module):
Collaborator:

What's the function of the ModelWrapper class? It seems like it adds functionality to the Transformer NO's forward pass. Some documentation would be helpful.

Member:

Is it possible to incorporate the things in ModelWrapper into the dataset instead? I think this is the convention we have been following in the library. @JeanKossaifi Feel free to correct me if I'm wrong.

self.enc_pos_emb_module = None
self.dec_pos_emb_module = None

self.encoder = nn.ModuleList(
Collaborator:

nit: FNOBlocks takes n_layers as a parameter instead of exposing an nn.ModuleList within the FNO itself. It might be nice to follow a similar convention here.
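For concreteness, that convention would look something like this (names here are illustrative, not a prescription):

```python
import torch.nn as nn

class TransformerEncoderBlocks(nn.Module):
    """Illustrative sketch: owns the stack of encoder layers, analogous to FNOBlocks."""

    def __init__(self, hidden_channels, n_heads, n_layers=4):
        super().__init__()
        self.n_layers = n_layers
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(hidden_channels, n_heads, batch_first=True)
             for _ in range(n_layers)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# In the model, instead of building an nn.ModuleList inline:
# self.encoder = TransformerEncoderBlocks(hidden_channels, n_heads, n_layers=n_layers)
```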

zijieli-Jlee (Contributor, PR author):

Got it;)

@zijieli-Jlee (Contributor, PR author):

Hi @dhpitt, thanks for the comments!

  1. I saw this was fixed in "copy zijieli-Jlee's fix for attention kernel layer test" (#295); let me know if there is still any remaining issue.
  2. Definitely. I will add tests for the TransformerNO if its current structure looks fine (see the next point).
  3. Thanks for the suggestion! I added transformer_block to layers, following fno_block; it contains implementations of TransformerEncoderBlock and TransformerDecoderBlock. I also added tests for these modules in test_transformer_block.

@dhpitt (Collaborator) commented on Mar 7, 2024:

@zijieli-Jlee, thanks for the edits; your PR is in good shape. If you fix the conflicts, we can re-run the tests and see if it's ready to go.

@dhpitt (Collaborator) commented on Mar 26, 2024:

Thanks for the fixes, @zijieli-Jlee! I fixed a one-line conflict with a model import that arose since the last update to main.

This PR is looking great overall, but I think your layers and util functions could use clearer documentation in their docstrings; see neuralop/layers/integral_transform.py for an example of informative docs.
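For example, a NumPy-style docstring along these lines would go a long way (the parameter names here are placeholders, not the block's actual signature):

```python
import torch.nn as nn


class TransformerEncoderBlock(nn.Module):
    """Self-attention encoder block for the Transformer neural operator.

    Parameters
    ----------
    hidden_channels : int
        Width of the token embeddings processed by the block.
    n_heads : int
        Number of attention heads.
    norm : str, optional
        Normalization applied before the attention layer, by default "layer_norm".
    """
```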

@zijieli-Jlee (Contributor, PR author):

I added docstrings to transformer_block and to some of the newly added positional encoding classes. In addition, I just spotted and fixed a bug in the normalization inside the attention layer.

@dhpitt (Collaborator) left a comment:

This PR is looking great. My only comment is a small nitpick about putting the source attribution for the borrowed code inside the docstrings. Otherwise I think this is ready to go.

pos_src: torch.Tensor, grid point coordinates of shape [batch_size, num_src_grid_points, channels]
pos_emb_module: nn.Module, positional embedding module, by default None
pos_qry: torch.Tensor, grid point coordinates of shape [batch_size, num_sry_grid_points, channels],
by default None and is set to pos_src
Collaborator:

It would be useful to state the specific purpose of pos_qry as compared to pos_src: these are the inputs to the query basis function.
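For example, the entry could read something like this (suggested wording only):

```
pos_qry: torch.Tensor, coordinates of the query points at which the output function is
    evaluated, i.e. the inputs to the query basis function in the decoder, of shape
    [batch_size, num_qry_grid_points, channels]; by default None, in which case pos_src is used
```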

# Gaussian random Fourier features
# code modified from: https://github.com/ndahlquist/pytorch-fourier-feature-networks
class GaussianFourierFeatureTransform(nn.Module):
Collaborator:

nit: these comments probably belong in the docstring
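Concretely, something like this (sketch of the docstring placement only):

```python
import torch.nn as nn


class GaussianFourierFeatureTransform(nn.Module):
    """Gaussian random Fourier feature embedding of input coordinates.

    Code modified from:
    https://github.com/ndahlquist/pytorch-fourier-feature-networks
    """
```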


# SirenNet
# code modified from: https://github.com/lucidrains/siren-pytorch/blob/master/siren_pytorch/siren_pytorch.py
Collaborator:

Same as above: it would be cleaner to move these comments into the docstring.

@dhpitt (Collaborator) commented on Apr 2, 2024:

Looks great! Asking @JeanKossaifi for final approval.

@mliuschi (Member) left a comment:

I left only a couple of comments because it mostly looks good to me!

pos_src=pos,
positional_embedding_module=pos_emb_module,
**kwargs)
u = u + u_attention_skip
Member:

Have you noticed any empirical improvements from making this skip connection a pointwise linear layer (not an MLP) rather than an identity connection? If so, it could be interesting to add this as an option.
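For clarity, the variant I have in mind would replace the identity skip with a learned per-point linear map on the skip branch, roughly like this (shapes and placement are assumptions, not the PR's code):

```python
import torch
import torch.nn as nn

hidden_channels = 64
# A single Linear applied independently at every grid point (equivalent to a
# 1x1 convolution over the point dimension), with no hidden layer or nonlinearity.
skip_proj = nn.Linear(hidden_channels, hidden_channels)

u_attention_skip = torch.randn(4, 256, hidden_channels)  # value saved before attention
u = torch.randn_like(u_attention_skip)                   # attention output

u_identity = u + u_attention_skip              # identity skip, as in the quoted code
u_pointwise = u + skip_proj(u_attention_skip)  # pointwise linear skip variant
```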

dim_in: int, Number of input channels.
dim_out: int, Number of output channels.
w0: float, scaling factor (denominator) used to initialize the weights, by default 6.
Collaborator:

w0 seems to default to 1 instead of 6?

Collaborator:

It is also used to initialize the Sine activation scaling; is this intentional?

weight = torch.zeros(dim_out, dim_in)
bias = torch.zeros(dim_out) if use_bias else None
self.init_(weight, bias, c=c, w0=w0)
Collaborator:

Is there a reason that we do not use torch.nn.Linear and torch.nn.init.uniform_ to initialize?
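That is, something along these lines (a sketch assuming the standard SIREN uniform bound; the helper name is made up):

```python
import math
import torch.nn as nn


def make_siren_linear(dim_in, dim_out, w0=1.0, c=6.0, use_bias=True):
    """Build a SIREN layer with torch.nn.Linear and torch.nn.init.uniform_
    instead of allocating the weight and bias tensors by hand."""
    linear = nn.Linear(dim_in, dim_out, bias=use_bias)
    bound = math.sqrt(c / dim_in) / w0  # standard SIREN uniform bound
    nn.init.uniform_(linear.weight, -bound, bound)
    if use_bias:
        nn.init.uniform_(linear.bias, -bound, bound)
    return linear
```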

Collaborator:

These edits could be removed from the PR



class SirenNet(nn.Module):
Collaborator:

@dhpitt Should we move this to layers/mlp.py?

Collaborator:

That makes sense to me



class GaussianFourierFeatureTransform(nn.Module):
Collaborator:

@dhpitt We now have two Fourier feature embeddings, i.e., PositionalEmbedding above and this one. I think we should name them more consistently, e.g., GaussianFourierEmbedding and FourierEmbedding?

Collaborator:

Good idea @zijieli-Jlee

@dhpitt (Collaborator) left a comment:

This all looks great. @JeanKossaifi
