diff --git a/jax_md/_nn/nequip.py b/jax_md/_nn/nequip.py index d9dc3da8..e233b9a8 100644 --- a/jax_md/_nn/nequip.py +++ b/jax_md/_nn/nequip.py @@ -192,7 +192,7 @@ def __call__( # we gather the instructions for the tp as well as the tp output irreps mode = 'uvu' - trainable = 'True' + trainable = True irreps_after_tp = [] instructions = []