# Releases: e3nn/e3nn-jax
## 2024-01-26

### Added

- `e3nn.where` function
- Optional `mask` argument in `e3nn.flax.BatchNorm`

### Changed

- Replace `jnp.ndarray` by `jax.Array`
## 2024-01-05

### Added

- `e3nn.ones` and `e3nn.ones_like` functions
- `e3nn.equinox` submodule

### Fixed

- Python 3.9 compatibility

Thanks to @ameya98, @SauravMaheshkar and @pabloferz
## 2023-12-24

## 2023-11-17

### Added

- `e3nn.flax.BatchNorm`
- `e3nn.scatter_mean`
- Add `e3nn.utils.vmap` also directly to the `e3nn` module: `e3nn.vmap`
## 2023-09-25

### Added

- `with_bias` argument to `e3nn.haiku.MultiLayerPerceptron` and `e3nn.flax.MultiLayerPerceptron`

### Fixed

- Improve compilation speed and stability of `s2grid` for large `lmax` (use `is_normalized=True` in `lpmn_values`)
## 2023-09-13

### Changed

- Add back the optimizations with the lazy `._chunks` attribute that were removed in 0.19.0
## 2023-09-09

### Highlight

tl;dr: mostly fixes issue #38.

In version 0.19.0, I removed the lazy `_list` attribute of `IrrepsArray` to fix the issues with `tree_util`, `grad` and `vmap`.
In this version (0.20.0), I found a way to put that lazy attribute back, now called `_chunks`, in a way that does not interfere with `tree_util`, `grad` and `vmap`. `_chunks` is dropped when using `tree_util`, `grad` and `vmap`, unless you use `e3nn.vmap`.
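The mechanism can be sketched in plain Python (a toy stand-in, not e3nn's actual implementation): the chunked view is computed lazily and cached, and the cache is simply not carried through flatten/unflatten, which is how `_chunks` gets dropped by `tree_util`.

```python
class LazyChunksArray:
    """Toy stand-in for IrrepsArray's lazy ``_chunks`` cache."""

    def __init__(self, data, sizes, chunks=None):
        self.data = data       # flat list of numbers
        self.sizes = sizes     # chunk sizes, like irrep dimensions
        self._chunks = chunks  # lazily computed view, may be None

    @property
    def chunks(self):
        # Compute the chunked view on first access and cache it.
        if self._chunks is None:
            out, i = [], 0
            for s in self.sizes:
                out.append(self.data[i:i + s])
                i += s
            self._chunks = out
        return self._chunks

    def tree_flatten(self):
        # The cache is deliberately dropped: only ``data`` is a leaf.
        return [self.data], self.sizes

    @classmethod
    def tree_unflatten(cls, sizes, leaves):
        return cls(leaves[0], sizes)  # fresh object, empty cache


x = LazyChunksArray([1, 2, 3, 4, 5], (2, 3))
_ = x.chunks                 # cache is now populated
leaves, aux = x.tree_flatten()
y = LazyChunksArray.tree_unflatten(aux, leaves)
assert y._chunks is None     # cache was dropped, as with tree_util
assert y.chunks == [[1, 2], [3, 4, 5]]
```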
### Changelog

#### Added

- `e3nn.Irreps.mul_gcd`
- `e3nn.IrrepsArray.extend_with_zeros` to extend an array with zeros, which can be useful for residual connections
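The residual-connection use case can be illustrated generically (plain NumPy here, not e3nn's API): when a layer's output has more features than its input, pad the input with zeros on the feature axis so the two can be added.

```python
import numpy as np

def residual_add(x, y):
    """Add ``x`` to ``y`` after extending ``x`` with zeros on the last
    axis so the shapes match (assumes ``y`` has at least as many features)."""
    pad = y.shape[-1] - x.shape[-1]
    x_ext = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])
    return x_ext + y

x = np.ones((4, 3))   # input of a block
y = np.ones((4, 5))   # output of the block, with more features
out = residual_add(x, y)
assert out.shape == (4, 5)
assert np.allclose(out[:, :3], 2.0) and np.allclose(out[:, 3:], 1.0)
```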
#### Changed

- Rewrite `e3nn.tensor_square` to be simpler (and faster?)
- Use `jax.scipy.special.lpmn_values` to implement `e3nn.legendre`. Faster on GPU and supports reverse-mode differentiation.
- [BREAKING] Change the output format of `e3nn.legendre`!

#### Fixed

- Add back a lazy `._chunks` attribute in `e3nn.IrrepsArray` to fix issue #38
## 2023-06-24

### Changelog

#### Fixed

- Fix missing support for zero flags in `e3nn.elementwise_tensor_product`
## 2023-06-23

By merging two `jnp.einsum` calls into one, the tensor product is faster than before (60% faster in the case I tested, see `BENCHMARK.md`).
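The fusion can be illustrated generically (plain NumPy with made-up shapes; the actual change is internal to e3nn's tensor product): two chained contractions are merged into a single `einsum`, which lets the backend optimize the whole contraction path at once.

```python
import numpy as np

# Hypothetical shapes, for illustration only.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))   # batch of input features
w = rng.normal(size=(4, 5))   # weights
cg = rng.normal(size=(5, 6))  # coupling coefficients

# Two chained contractions ...
tmp = np.einsum("zi,ij->zj", x, w)
out_two = np.einsum("zj,jk->zk", tmp, cg)

# ... merged into a single einsum call:
out_one = np.einsum("zi,ij,jk->zk", x, w, cg)

assert out_one.shape == (8, 6)
assert np.allclose(out_one, out_two)
```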
### Changelog

#### Changed

- [BREAKING] Move `Instruction`, `FunctionalTensorProduct` and `FunctionalFullyConnectedTensorProduct` into the `e3nn.legacy` submodule
- Reimplement `e3nn.tensor_product` and `e3nn.elementwise_tensor_product` in a simpler way
## 2023-06-22

### Highlight

`e3nn.utils.vmap` makes it possible to bypass the safeguard that drops `.zero_flags` in the case of `vmap`.

Consider an irreps array with the `0o` entry set to `None`; its `zero_flags` attribute will be `(False, True)`:

```python
x = e3nn.from_chunks("0e + 0o", [jnp.ones((100, 1, 1)), None], (100,))
x.zero_flags  # (False, True)
```

Now if we vmap a function using `jax.vmap`, the inner function will not get the information that the `0o` entry is actually zero:

```python
jax.vmap(e3nn.norm)(x).zero_flags  # (False, False)
```

This is a safeguard, because not all transformations preserve the validity of `zero_flags`; take for instance:

```python
jax.tree_util.tree_map(lambda x: x + 1.0, x).zero_flags  # (False, False)
```

However, in the case of `vmap`, vectorization does preserve the validity of `zero_flags`, so in this case we can allow it to propagate in and out of the vectorized function:

```python
e3nn.utils.vmap(e3nn.norm)(x).zero_flags  # (False, True)
```
### Changelog

#### Added

- `e3nn.utils.vmap` to propagate `zero_flags` through the vectorized function

#### Changed

- Simplify the tetris examples

#### Fixed

- Example of what is fixed: assuming `x.ndim == 2`, allow `x[:, None]` but prevent `x[:, :, None]` and `x[..., None]`
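The rule in the last bullet can be sketched as a standalone check (a simplified stand-in for e3nn's actual indexing logic, whose point is to keep the trailing irreps axis last): inserting a new axis is only allowed before the existing axes are exhausted.

```python
def new_axis_allowed(ndim: int, index: tuple) -> bool:
    """Return False if ``index`` would insert a new axis after the
    last axis of an array with ``ndim`` dimensions (which, in an
    IrrepsArray-like container, would displace the irreps axis)."""
    consumed = 0  # number of existing axes matched so far
    for item in index:
        if item is None:  # np.newaxis
            if consumed >= ndim:
                return False
        elif item is Ellipsis:
            consumed = ndim  # an ellipsis jumps past the remaining axes
        else:
            consumed += 1
    return True

# With x.ndim == 2:
assert new_axis_allowed(2, (slice(None), None))                   # x[:, None]
assert not new_axis_allowed(2, (slice(None), slice(None), None))  # x[:, :, None]
assert not new_axis_allowed(2, (Ellipsis, None))                  # x[..., None]
```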