Jax support #55

Open
emilemathieu opened this issue May 22, 2023 · 9 comments

@emilemathieu

emilemathieu commented May 22, 2023

Hi,

Thanks for this really cool library!

Being a jax user, I was wondering whether you've thought about extending it to support jax, akin to e3nn-jax?

At first glance it seems that most of the library is 'pure' python, apart from the GeometricTensor and FieldType classes in escnn/nn/*, which seem easily translatable to jax.Array etc, and obviously all the layers in escnn/nn/modules, which would need to be rewritten for flax / haiku / equinox.
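
To illustrate the kind of translation described above, here is a minimal sketch of an array wrapper that carries a type tag through `jax.jit` by registering itself as a pytree. The `TypedArray` class, the string tag `"regular_C8"` and the `scale` function are hypothetical names made up for this example; they are not part of escnn or of any port, and a real `FieldType` carries far more structure than a string.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
@dataclass(frozen=True)
class TypedArray:
    """Hypothetical stand-in for a GeometricTensor: data plus a static type tag."""

    array: jax.Array
    field_type: str  # assumption: a hashable placeholder for an escnn FieldType

    def tree_flatten(self):
        # The array is a traced leaf; the type tag is static auxiliary data.
        return (self.array,), self.field_type

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)


@jax.jit
def scale(x: TypedArray, s: float) -> TypedArray:
    # The type tag passes through jit untouched because it is not a leaf.
    return TypedArray(x.array * s, x.field_type)


x = TypedArray(jnp.ones((2, 8)), "regular_C8")
y = scale(x, 3.0)
```

Keeping the type information in the pytree's auxiliary data means it stays an ordinary Python object at trace time, so only the numerical array is traced.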

I'd be happy to help out with this :)

Best,
Emile

@emilemathieu
Author

@Gabri95 ?

@Gabri95
Collaborator

Gabri95 commented Jun 14, 2023

Hi @emilemathieu

I am sorry for the late reply, but I was on vacation for the last couple of weeks.

I am not currently planning to support other frameworks like Jax.
Adapting the library to Jax requires rewriting most of the escnn.nn and the escnn.kernels modules.
Unfortunately, I don't have enough time for that currently :(

I am not very familiar with Jax, so it is hard for me to precisely estimate how much work it would be, but I'd be very happy to provide some support if a user would like to try doing that!

Best,
Gabriele

@emilemathieu
Author

Thanks @Gabri95!
I've started and got escnn.nn.Linear to work (i.e. the tests from test/test_linear.py pass), building on the equinox module.
I'm now trying to get escnn.nn.Conv2d to work, and then to be able to train a simple model based on these layers.
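
As a rough illustration of what an equivariant linear layer and its equivariance test involve, here is a small sketch in plain JAX. It is not the ported escnn.nn.Linear and does not use equinox. It assumes the simplest possible setting: one copy of the regular representation of the cyclic group C_8 as both input and output, where the equivariant maps are exactly the linear combinations of the 8 cyclic shift matrices. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp

N = 8  # order of the cyclic group C_N (assumed for illustration)


def shift_basis(n):
    """Cyclic shift matrices: a basis of the C_n-equivariant maps
    between two copies of the regular representation."""
    eye = jnp.eye(n)
    return jnp.stack([jnp.roll(eye, k, axis=0) for k in range(n)])  # (n, n, n)


def init_params(key, n):
    # One learnable coefficient per basis element.
    return jax.random.normal(key, (n,)) / jnp.sqrt(n)


def equivariant_linear(params, basis, x):
    # Expand the coefficients in the basis to obtain a circulant weight matrix.
    weight = jnp.einsum("k,kij->ij", params, basis)
    return x @ weight.T


basis = shift_basis(N)
params = init_params(jax.random.PRNGKey(0), N)
x = jax.random.normal(jax.random.PRNGKey(1), (4, N))

# Equivariance check: acting with a group element (a cyclic shift by g)
# before the layer should match acting with it after the layer.
g = 3
out_of_shifted = equivariant_linear(params, basis, jnp.roll(x, g, axis=-1))
shifted_out = jnp.roll(equivariant_linear(params, basis, x), g, axis=-1)
is_equivariant = bool(jnp.allclose(out_of_shifted, shifted_out, atol=1e-5))
```

The layer only learns the coefficients of a fixed basis, so equivariance holds for any parameter values rather than having to be learned.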

@emilemathieu
Author

Hey @Gabri95,
I gave it a try and I can now reproduce the C8SteerableCNN on MNIST with a ~20% speed-up!
There are still quite a few things to improve and modules to support, as listed in the README.md, but hopefully other people will be interested in helping :)

@emilemathieu
Author

@Gabri95
I'm thinking of making this available via pip install as escnn_jax, akin to e3nn_jax.
Would you have any opinion on this?

@Gabri95
Collaborator

Gabri95 commented Jul 4, 2023

Hi @emilemathieu

Thanks a lot for porting the library to Jax!

The escnn_jax name seems good to me!

I need to push some updates on the library in the coming week (sorry for being a little offline recently :/ ) and I will also include a pointer to your code in the documentation.

Thanks a lot again!

Gabriele

@emilemathieu
Author

Happy that it eventually works, haha. It was more work than expected and there are still quite a few layers not supported, but it should be easy to add them, taking inspiration from the ones I've ported :)

  • Would you know how to make the library available for pip install escnn_jax?
  • Should I move the repo to QUVA-Lab?

@Gabri95
Collaborator

Gabri95 commented Jul 7, 2023

This is a good question, let me ask for some advice about this and come back when I know more, sorry.

Regarding pypi, you can follow these instructions (but I'd recommend trying with testpypi first).

Gabriele

@Gabri95
Collaborator

Gabri95 commented Jul 12, 2023

Hi @emilemathieu

I've just included a reference to your fork in the README of the library!

Regarding where you should host the repo, both options seem possible: keeping it where it is now or moving it under QUVA-Lab.
I don't have a strong preference but I see two points to keep in mind.
Keeping it under your account will give you more visibility (and credit 😉 ), which I think is fair.
Moving it under QUVA-Lab might make it easier for someone else to take over in the future if you don't want to maintain it anymore.

Again, I don't have a strong preference, but maybe we can keep it under your account for the moment and move it to QUVA-Lab in the future.
What do you think about this?

Thanks again for this amazing work! 😄

Best,
Gabriele
