e3nn tensors compatibility issue #9

Open
liyy2 opened this issue Mar 7, 2024 · 3 comments
Comments


liyy2 commented Mar 7, 2024

Hi, I am trying to integrate this with the e3nn package.

For the SO3Embedding class, how can I convert it to an irreps tensor that is compatible with the e3nn convention?
My implementation (not sure whether this is right):

    def to_e3nn_embeddings(self):
        from e3nn.io import SphericalTensor
        from e3nn.o3 import Irreps
        # flatten the per-node embedding to a single vector per node
        embedding = self.embedding.reshape(self.length, -1)

        # SphericalTensor(lmax, 1, -1) is the irreps 1x0e+1x1o+1x2e+...;
        # replace the multiplicity 1 with num_channels to get multiple channels
        irreps = Irreps(str(SphericalTensor(self.lmax_list[-1], 1, -1)).replace('1x', f'{self.num_channels}x'))
        return irreps, embedding
yilunliao (Member) commented

Hi @liyy2

I am not familiar with SphericalTensor.

But tensors in e3nn are typically in the form C_0x0e + C_1x1e + ... (e.g., 128x0e+128x1e+...).
(Let me know if that is not clear.)

For EquiformerV2, the tensors are in the form (0e+1e+..., C) and have shape ((1 + L_max) ** 2, C).
We require the number of channels to be the same for each degree here.
(Let me know if that is not clear.)
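As a quick sanity check (a small sketch, assuming lmax = 2 and 128 channels), the two layouts have the same total size per node:

from e3nn import o3

# e3nn: one flat vector per node, degrees concatenated
irreps = o3.Irreps('128x0e+128x1e+128x2e')
print(irreps.dim)  # 1152

# EquiformerV2: a ((1 + lmax) ** 2, C) tensor per node
print((1 + 2) ** 2 * 128)  # 1152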

So to convert between these two formats, we extract all the channels for each degree, flatten them, and concatenate the flattened tensors.
Here is an example of converting e3nn tensors to EquiformerV2 tensors:

import torch
from e3nn import o3

lmax = 2
num_channels = 128
irreps = o3.Irreps('128x0e+128x1e+128x2e')
tensor_e3nn = irreps.randn(1, -1)  # shape: (1, num_channels * (1 + lmax) ** 2)

out = []
start_idx = 0
for l in range(lmax + 1):
    length = (2 * l + 1) * num_channels
    feature = tensor_e3nn.narrow(1, start_idx, length)  # extract all the channels corresponding to degree l
    feature = feature.view(-1, num_channels, (2 * l + 1))  # e3nn stores the (2l + 1) components of each channel contiguously
    feature = feature.transpose(1, 2).contiguous()  # -> (batch, 2l + 1, num_channels)
    out.append(feature)
    start_idx = start_idx + length
tensor_equiformer_v2 = torch.cat(out, dim=1)  # shape: (1, (1 + lmax) ** 2, num_channels)

You can follow the above example to do the reverse.
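For reference, here is a minimal sketch of the reverse direction (EquiformerV2 layout back to the e3nn layout), assuming the same lmax, num_channels, and the tensor_equiformer_v2 produced above:

import torch

lmax = 2
num_channels = 128

# tensor_equiformer_v2 has shape (batch, (1 + lmax) ** 2, num_channels)
out = []
start_idx = 0
for l in range(lmax + 1):
    length = 2 * l + 1
    feature = tensor_equiformer_v2.narrow(1, start_idx, length)  # (batch, 2l + 1, num_channels)
    feature = feature.transpose(1, 2).contiguous()  # channel-major per degree, matching e3nn's layout
    out.append(feature.reshape(feature.shape[0], -1))
    start_idx = start_idx + length
tensor_e3nn = torch.cat(out, dim=1)  # shape: (batch, num_channels * (1 + lmax) ** 2)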


liyy2 commented Mar 27, 2024

Hi, thank you for the detailed response. My question is: does parity impact the model here? Should I use o3.Irreps('128x0e+128x1e+128x2e') or o3.Irreps('128x0e+128x1o+128x2e')?

yilunliao (Member) commented

For EquiformerV2, we currently use SE(3) (no parity / inversion), and therefore we should use '128x0e+128x1e+128x2e'.
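As a small illustration (a sketch, not code from the repository), the corresponding even-parity irreps can also be built programmatically from lmax and the channel count:

from e3nn import o3

lmax = 2
num_channels = 128

# every degree uses even parity (p = 1) because EquiformerV2 uses SE(3) equivariance (no inversion)
irreps = o3.Irreps([(num_channels, (l, 1)) for l in range(lmax + 1)])
print(irreps)  # 128x0e+128x1e+128x2e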
