Releases: e3nn/e3nn-jax

2024-01-26

Added

  • e3nn.where function (see the sketch below)
  • Add optional mask argument in e3nn.flax.BatchNorm
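
A minimal sketch of e3nn.where. Its semantics are assumed to mirror jnp.where on the underlying data, and the scalar condition below is only illustrative; the e3nn.flax.BatchNorm mask is not shown because its exact calling convention is not spelled out in this note.

import jax.numpy as jnp
import e3nn_jax as e3nn

x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 3.0]))
y = e3nn.zeros("1o", ())
# assumption: e3nn.where broadcasts the condition like jnp.where
z = e3nn.where(jnp.array(True), x, y)  # picks x here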

Changed

  • Replace jnp.ndarray with jax.Array

2024-01-05

Added

  • e3nn.ones and e3nn.ones_like functions (example below)
  • e3nn.equinox submodule
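
A small sketch, assuming e3nn.ones and e3nn.ones_like mirror the existing e3nn.zeros and e3nn.zeros_like helpers:

import e3nn_jax as e3nn

x = e3nn.ones("0e + 1o", (10,))  # assumed signature: (irreps, leading_shape)
y = e3nn.ones_like(x)            # same irreps and leading shape as x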

Fixed

  • Python 3.9 compatibility

Thanks to @ameya98, @SauravMaheshkar and @pabloferz

2023-12-24

Fixed

  • Fix pyproject.toml; the documentation build was broken. Thanks to @SauravMaheshkar!

Added

  • Support for s2fft in e3nn.to_s2grid and e3nn.from_s2grid (round-trip sketch below), thanks to @ameya98!
  • Add a special-case implementation for e3nn.scatter_mean when map_back is set and nel is not None.
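
A round-trip sketch of the grid transforms. The argument layout follows the pre-existing API as I understand it; how the s2fft backend is selected is not specified here, so no toggle is shown.

import jax
import e3nn_jax as e3nn

coeffs = e3nn.normal("0e + 1o + 2e", jax.random.PRNGKey(0))
sig = e3nn.to_s2grid(coeffs, 50, 49, quadrature="soft")  # sample on a (beta, alpha) grid
back = e3nn.from_s2grid(sig, "0e + 1o + 2e")             # back to irreps coefficients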

2023-11-17

Added

  • e3nn.flax.BatchNorm
  • e3nn.scatter_mean (usage sketch below)
  • Expose e3nn.utils.vmap directly on the e3nn module as e3nn.vmap
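
A usage sketch of e3nn.scatter_mean; the dst and output_size keywords are assumed to follow the pattern of e3nn.scatter_sum:

import jax.numpy as jnp
import e3nn_jax as e3nn

x = e3nn.IrrepsArray("0e", jnp.array([[1.0], [2.0], [4.0]]))
dst = jnp.array([0, 0, 1])
# average the entries that share a destination index
e3nn.scatter_mean(x, dst=dst, output_size=2)  # expected: [[1.5], [4.0]]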

2023-09-25

Added

  • with_bias argument to e3nn.haiku.MultiLayerPerceptron and e3nn.flax.MultiLayerPerceptron (usage sketch below)
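
A usage sketch for the flax variant, assuming the layer widths and activation are passed as in earlier releases and that the module follows the usual flax init/apply pattern:

import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

mlp = e3nn.flax.MultiLayerPerceptron((64, 64, 1), act=jax.nn.gelu, with_bias=True)
x = jnp.ones((10, 32))  # plain scalar features
params = mlp.init(jax.random.PRNGKey(0), x)
out = mlp.apply(params, x)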

Fixed

  • Improve compilation speed and stability of s2grid for large lmax (by using is_normalized=True in lpmn_values)

2023-09-13

Changelog

Changed

  • Add back the optimization with the lazy ._chunks attribute that was removed in 0.19.0

2023-09-09

Highlight

tl;dr: Mostly fixes issue #38.

In version 0.19.0, I removed the lazy _list attribute of IrrepsArray to fix the issues with tree_util, grad and vmap.
In this version (0.20.0), I found a way to bring that lazy attribute back, now called _chunks, in a way that does not interfere with tree_util, grad and vmap: _chunks is dropped when using tree_util, grad and vmap, unless you use e3nn.vmap.
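
A minimal sketch of the behavior, assuming the public .chunks accessor reflects the lazy _chunks attribute (a None chunk encodes a block that is known to be zero):

import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

x = e3nn.from_chunks("0e + 0o", [jnp.ones((1, 1)), None], ())
x.chunks  # [Array, None]: the zero block stays lazy

# going through tree_util drops the lazy information, per the note above
y = jax.tree_util.tree_map(lambda a: a, x)
y.chunks  # [Array, Array]: the zero block is now materialized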

Changelog

Added

  • e3nn.Irreps.mul_gcd
  • e3nn.IrrepsArray.extend_with_zeros to extend an array with zeros; useful for residual connections (sketch below)
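
A quick sketch of the two helpers. mul_gcd is assumed to return the greatest common divisor of the multiplicities, and the extend_with_zeros call is hypothetical since its exact signature is not given here:

import jax.numpy as jnp
import e3nn_jax as e3nn

e3nn.Irreps("4x0e + 2x1o").mul_gcd  # assumed: gcd of the multiplicities, here 2

x = e3nn.IrrepsArray("0e", jnp.array([1.0]))
y = x.extend_with_zeros("0e + 1o")  # hypothetical: pad with zeros up to the target irreps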

Changed

  • Rewrite e3nn.tensor_square to be simpler (and faster?); minimal call below
  • Use jax.scipy.special.lpmn_values to implement e3nn.legendre. Faster on GPU and supports reverse-mode differentiation.
  • [BREAKING] Change the output format of e3nn.legendre!
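
For reference, a minimal call of the rewritten e3nn.tensor_square; the output decomposition is printed rather than asserted here:

import jax.numpy as jnp
import e3nn_jax as e3nn

x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 3.0]))
y = e3nn.tensor_square(x)  # x tensor x, decomposed into irreps
print(y.irreps, y.shape)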

Fixed

  • Add back a lazy ._chunks in e3nn.IrrepsArray to fix issue #38

2023-06-24

Changelog

Fixed

  • Fix missing support for zero flags in e3nn.elementwise_tensor_product

2023-06-23

By merging two jnp.einsum calls into one, the tensor product is faster than before (60% faster in the case I tested; see BENCHMARK.md).

Changelog

Changed

  • [BREAKING] Move Instruction, FunctionalTensorProduct and FunctionalFullyConnectedTensorProduct into e3nn.legacy submodule
  • Reimplement e3nn.tensor_product and e3nn.elementwise_tensor_product in a simpler way (usage sketch below)
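
A quick sketch of the reimplemented entry point; the legacy classes are reachable as stated above (e.g. e3nn.legacy.FunctionalTensorProduct):

import jax.numpy as jnp
import e3nn_jax as e3nn

x = e3nn.IrrepsArray("1o", jnp.array([1.0, 0.0, 0.0]))
y = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0]))
z = e3nn.tensor_product(x, y)  # 1o x 1o -> 0e + 1e + 2e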

2023-06-22

Highlight

e3nn.utils.vmap makes it possible to bypass the safety measure that drops .zero_flags in the case of vmap.

Consider an IrrepsArray with the 0o entry set to None; its zero_flags attribute will be (False, True):

import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

x = e3nn.from_chunks("0e + 0o", [jnp.ones((100, 1, 1)), None], (100,))
x.zero_flags  # (False, True)

Now, if we vmap a function using jax.vmap, the inner function does not get the information that the 0o entry is actually zero.

jax.vmap(e3nn.norm)(x).zero_flags  # (False, False)

This is a safety measure because not all transformations preserve the validity of zero_flags; take for instance:

jax.tree_util.tree_map(lambda x: x + 1.0, x).zero_flags  # (False, False)

However, vectorization does preserve the validity of zero_flags, so for vmap we can let it propagate in and out of the vectorized function:

e3nn.utils.vmap(e3nn.norm)(x).zero_flags  # (False, True)

Changelog

Added

  • e3nn.utils.vmap to propagate zero_flags in the vectorized function.

Changed

  • Simplify the tetris examples

Fixed

  • Example of what is fixed: assuming x.ndim == 2, allow x[:, None] but prevent x[:, :, None] and x[..., None]