Add a transformer neural operator model and an accompanying example for training it on Darcy flow #293
base: main
Conversation
Thanks @zijieli-Jlee! This looks awesome. A few minor things:
- Thanks for fixing AttentionKernelLayer fails test due to error in Einsum #290! Would it be OK to split up this PR into one that fixes the issue (which we can merge immediately) and one that implements the TransformerNO (which may take longer to go over)?
- It would be great to add unit tests for the TransformerNO model, along the lines of neuralop/models/tests/test_fno.py (just verifying that the forward pass produces outputs that we would expect; a sketch of such a test follows this list).
- It might help keep the code flexible and easier to maintain if the EncoderBlocks was its own module in neuralop.layers, along the lines of the FNOBlocks.
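To make the unit-test suggestion concrete, here is a minimal sketch of a forward-pass test modeled on test_fno.py; the TransformerNO constructor and call signature below are illustrative assumptions, not necessarily the exact API in this PR.

import torch
from neuralop.models import TransformerNO  # assumes the PR exports TransformerNO here


def test_transformer_no_forward():
    # Hypothetical arguments; the real constructor in this PR may differ.
    model = TransformerNO(
        n_dim=2,
        in_channels=3,
        out_channels=1,
        hidden_channels=32,
        n_layers=2,
    )
    batch_size, n_points = 4, 64
    x = torch.randn(batch_size, n_points, 3)   # input function values
    pos = torch.rand(batch_size, n_points, 2)  # grid point coordinates
    out = model(x, pos)
    # Only check that the forward pass runs, produces the expected shape,
    # and that gradients flow back to the parameters.
    assert out.shape == (batch_size, n_points, 1)
    out.sum().backward()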
Thank you for the updates @zijieli-Jlee, the PR is in really good shape overall. I left a few small comments
model = get_model(config)

class ModelWrapper(torch.nn.Module):
What's the function of the ModelWrapper class? It seems like it adds functionality to the Transformer NO's forward pass. Some documentation would be helpful.
Is it possible to incorporate the things in ModelWrapper into the dataset instead? I think this has been the convention we have been following in the library. @JeanKossaifi Feel free to correct me if I'm wrong.
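For illustration only (the exact contents of ModelWrapper aren't shown here), that convention usually means doing the per-sample preprocessing in the dataset rather than wrapping the model's forward pass; a generic sketch with hypothetical names:

import torch
from torch.utils.data import Dataset


class DarcyPointCloudDataset(Dataset):
    """Hypothetical dataset that applies the preprocessing that would
    otherwise live in a wrapper around the model's forward pass."""

    def __init__(self, fields, coords, transform=None):
        self.fields = fields        # [n_samples, n_points, channels]
        self.coords = coords        # [n_samples, n_points, n_dim]
        self.transform = transform  # e.g. flattening or normalization

    def __len__(self):
        return len(self.fields)

    def __getitem__(self, idx):
        sample = {"x": self.fields[idx], "pos": self.coords[idx]}
        if self.transform is not None:
            # Preprocessing happens here, so the model's forward pass stays clean.
            sample = self.transform(sample)
        return sample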
neuralop/models/transformer_no.py (Outdated)
self.enc_pos_emb_module = None
self.dec_pos_emb_module = None

self.encoder = nn.ModuleList(
nit: FNOBlocks takes n_layers as a parameter instead of exposing a ModuleList within the FNO itself. It might be nice to have a similar convention here.
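For reference, that convention looks roughly like the sketch below, where the block module owns its own depth through an n_layers argument instead of the top-level model holding a ModuleList; the class and layer contents here are illustrative placeholders, not this PR's actual code.

import torch
from torch import nn


class EncoderBlocks(nn.Module):
    """Illustrative stack of encoder layers that owns its own depth,
    mirroring how FNOBlocks takes n_layers rather than the FNO model
    building a ModuleList itself."""

    def __init__(self, hidden_channels: int, n_layers: int = 4):
        super().__init__()
        self.n_layers = n_layers
        # Placeholder sub-layers; the real blocks would be attention kernel layers.
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_channels, hidden_channels) for _ in range(n_layers)]
        )

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            u = layer(u)
        return u


# The top-level model then only needs to pass a depth:
encoder = EncoderBlocks(hidden_channels=64, n_layers=4)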
Got it ;)
Hi @dhpitt, thanks for the comments!
@zijieli-Jlee thanks for the edits, your PR is in good shape. If you fix the conflicts we can re-run the tests and see if it's ready to go.
Thanks for the fixes @zijieli-Jlee! I fixed a one-line conflict with a model import since the last update to main. This PR is looking great overall, but I think your layers and util functions could use clearer documentation in the docstrings - see
I added docstrings to the transformer_block and to some of the newly added positional encoding classes. In addition, I just spotted and fixed a bug in the normalization inside the attention layer.
This PR is looking great. My only comment is a small nitpick about putting the model source within docstrings. Otherwise I think this is ready to go.
neuralop/layers/transformer_block.py (Outdated)
pos_src: torch.Tensor, grid point coordinates of shape [batch_size, num_src_grid_points, channels]
pos_emb_module: nn.Module, positional embedding module, by default None
pos_qry: torch.Tensor, grid point coordinates of shape [batch_size, num_qry_grid_points, channels],
    by default None and is set to pos_src
It would be useful to include the specific purpose of pos_qry as compared to pos_src - inputs to the query basis function.
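One possible wording that makes that distinction explicit (a suggestion only, phrased after the reviewer's comment):

# Suggested docstring wording (illustrative):
#
# pos_src : torch.Tensor
#     Grid point coordinates of the source (input) function,
#     shape [batch_size, num_src_grid_points, channels].
# pos_qry : torch.Tensor, optional
#     Coordinates of the query points at which the output is evaluated;
#     these are the inputs to the query basis function. By default None,
#     in which case pos_src is reused (queries coincide with the input grid).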
neuralop/layers/embeddings.py (Outdated)
# Gaussian random Fourier features
# code modified from: https://github.com/ndahlquist/pytorch-fourier-feature-networks
class GaussianFourierFeatureTransform(nn.Module):
nit: these comments probably belong in the docstring
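For example, the attribution comments could move into the class docstring along these lines (a sketch showing only the docstring placement; the implementation body is elided):

from torch import nn


class GaussianFourierFeatureTransform(nn.Module):
    """Gaussian random Fourier feature embedding of input coordinates.

    Code modified from:
    https://github.com/ndahlquist/pytorch-fourier-feature-networks
    """
    # ... implementation as in the PR ...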
neuralop/layers/embeddings.py (Outdated)
# SirenNet
# code modified from: https://github.com/lucidrains/siren-pytorch/blob/master/siren_pytorch/siren_pytorch.py
same as above - cleaner to put these comments in the docstring
Looks great! Asking @JeanKossaifi for final approval.
Left only a couple of comments because it mostly looks good to me!
pos_src=pos,
positional_embedding_module=pos_emb_module,
**kwargs)
u = u + u_attention_skip
Have you noticed any empirical improvements from having this skip connection be a pointwise linear (not MLP) as opposed to identity connection? If so, it could be interesting to supplement with this functionality.
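For context, the variant being asked about would roughly replace the identity residual above with a learned pointwise (channel-wise) linear map on the skip branch; a minimal sketch with assumed shapes, not the PR's code:

import torch
from torch import nn

batch_size, n_points, hidden_channels = 4, 128, 64
u = torch.randn(batch_size, n_points, hidden_channels)                 # attention output
u_attention_skip = torch.randn(batch_size, n_points, hidden_channels)  # saved skip branch

# Identity skip, as in the snippet above:
u_identity = u + u_attention_skip

# Pointwise-linear skip: a learned channel mixing applied to the skip branch.
skip_proj = nn.Linear(hidden_channels, hidden_channels)
u_linear = u + skip_proj(u_attention_skip)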
neuralop/layers/embeddings.py (Outdated)
dim_in: int, Number of input channels.
dim_out: int, Number of output channels.
w0: float, scaling factor (denominator) used to initialize the weights, by default 6.
w0 seems to default to 1 instead of 6?
It is also used to initialize the Sine activation scaling, is this intentional?
neuralop/layers/embeddings.py (Outdated)
weight = torch.zeros(dim_out, dim_in)
bias = torch.zeros(dim_out) if use_bias else None
self.init_(weight, bias, c=c, w0=w0)
Is there a reason that we do not use torch.nn.Linear and torch.nn.init.uniform_ to initialize?
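For reference, the alternative being suggested would look roughly as below; the uniform bound follows the usual SIREN-style initialization and is an assumption that should be kept consistent with the existing init_ method.

import math
from torch import nn


def make_siren_linear(dim_in, dim_out, c=6.0, w0=1.0, use_bias=True):
    # Build the layer with torch.nn.Linear and initialize it in place
    # with torch.nn.init.uniform_, instead of allocating raw tensors.
    linear = nn.Linear(dim_in, dim_out, bias=use_bias)
    w_std = math.sqrt(c / dim_in) / w0  # assumed bound, matching SIREN's scheme
    nn.init.uniform_(linear.weight, -w_std, w_std)
    if use_bias:
        nn.init.uniform_(linear.bias, -w_std, w_std)
    return linear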
neuralop/models/base_model.py (Outdated)
These edits could be removed from the PR
neuralop/layers/embeddings.py (Outdated)
class SirenNet(nn.Module):
@dhpitt Should we move this to layers/mlp.py?
That makes sense to me
neuralop/layers/embeddings.py (Outdated)
class GaussianFourierFeatureTransform(nn.Module):
@dhpitt We now have two Fourier feature embeddings, i.e., PositionalEmbedding above and this one --- I think we should name them more consistently, e.g., GaussianFourierEmbedding and FourierEmbedding?
Good idea @zijieli-Jlee
This all looks great. @JeanKossaifi
Main changes:
- transformer_no.py under models, which is a Transformer encoder-decoder architecture for learning operators. Small changes to the __init__.py and model_dispatcher.py accordingly.
- embeddings.py, which is useful for building the neural field (query point embedding) in the Transformer decoder.
- train_darcy_transformer.py under scripts; the corresponding config yaml file is darcy_transformer_config.yaml.
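A rough usage sketch, assuming TransformerNO is exported from neuralop.models as described above; the constructor arguments and tensor shapes are illustrative assumptions, not the exact interface. For end-to-end training on Darcy flow, the accompanying scripts/train_darcy_transformer.py together with darcy_transformer_config.yaml is the reference.

import torch
from neuralop.models import TransformerNO  # assumed export added by this PR

# Hypothetical instantiation for 2D Darcy flow; argument names are assumptions.
model = TransformerNO(
    n_dim=2,
    in_channels=1,        # permeability field
    out_channels=1,       # solution field
    hidden_channels=64,
    n_layers=4,
)

x = torch.randn(8, 421, 1)    # function values at the input grid points
pos = torch.rand(8, 421, 2)   # coordinates of those grid points
out = model(x, pos)           # by default, decode at the same points
print(out.shape)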