
FullyConnectedTensorProduct Feature Discrepancy in JAX vs. Torch #44 #45

Answered by mariogeiger
tai-dang11 asked this question in Q&A

Here is an example of how to do that in e3nn-jax:

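import haiku as hk
import jax
import e3nn_jax as e3nn


def model(x):
    # Tensor product of the input with itself
    y = e3nn.tensor_product(x, x)
    # Linear map to "0e + 1o" whose weights are passed in at call time
    lin = e3nn.FunctionalLinear(y.irreps, "0e + 1o")
    # An MLP on the scalar (0e) part of y predicts the weights of lin
    w = e3nn.haiku.MultiLayerPerceptron(
        [64, 64, lin.num_weights], jax.nn.silu, output_activation=False
    )(y.filter(keep="0e").array)
    return lin(w, y)


model = hk.without_apply_rng(hk.transform(model))
x = e3nn.normal("2x0e + 3x1o + 3x1e", jax.random.PRNGKey(0))

params = model.init(jax.random.PRNGKey(1), x)
y = model.apply(params, x)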
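
A note on why the MLP above only sees the "0e" part of y: scalars do not change under rotation, so the predicted weights do not change either, and the linear map stays equivariant. The sketch below is a toy NumPy analogue of that idea, not e3nn-jax code. The names toy_model, random_rotation, w1 and w2 are made up for illustration, and it assumes plain 3D vectors with pairwise dot products as the invariant features.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_rotation(rng):
    # QR of a Gaussian matrix gives an orthogonal matrix;
    # flip one column if needed so the determinant is +1.
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def toy_model(vectors, w1, w2):
    """vectors has shape (n_in, 3); the result has shape (n_out, 3)."""
    n_in = vectors.shape[0]
    # Rotation-invariant features: all pairwise dot products.
    scalars = (vectors @ vectors.T).reshape(-1)
    # A tiny MLP on the invariants predicts the channel-mixing weights.
    hidden = np.tanh(scalars @ w1)
    weights = (hidden @ w2).reshape(-1, n_in)  # (n_out, n_in)
    # Mixing channels with invariant weights commutes with rotations.
    return weights @ vectors


n_in, n_out, n_hidden = 3, 2, 16
w1 = rng.normal(size=(n_in * n_in, n_hidden))
w2 = rng.normal(size=(n_hidden, n_out * n_in))
v = rng.normal(size=(n_in, 3))
R = random_rotation(rng)

# Rotating the output should match feeding in the rotated input.
out_then_rotate = toy_model(v, w1, w2) @ R.T
rotate_then_out = toy_model(v @ R.T, w1, w2)
equivariant = np.allclose(out_then_rotate, rotate_then_out)
print("equivariant:", equivariant)
```

If the MLP were fed non-scalar components instead, the weights would change with the orientation of the input and this check would generally fail.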

Answer selected by mariogeiger