
issues in AntiBERTyRunner.py #1

Open

Elmiar0642 opened this issue Aug 4, 2023 · 0 comments
Hey there,

I attempted to re-run the new v3.0.x of IGFold with OpenMM on my system last night. After updating and upgrading the packages, I tried to run the notebook and found the following error thrown from AntiBERTyRunner.py:

File "/xxx/yyy/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
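For context, this error reproduces in isolation whenever a CPU tensor is indexed with a CUDA mask. A minimal sketch (`emb` and `mask` are hypothetical stand-ins, and it assumes a CUDA-capable machine):

```python
import torch

# hypothetical stand-ins for the variables in AntiBERTyRunner.embed
emb = torch.randn(9, 120, 512)  # detached to the CPU, like embeddings
mask = torch.ones(120, dtype=torch.long, device="cuda")  # still on the GPU

# indexing a CPU tensor with a CUDA boolean mask raises:
# RuntimeError: indices should be either on cpu or on the same device
# as the indexed tensor (cpu)
emb[:, mask == 1]
```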

To resolve this, I checked which devices the variables embeddings and attention_mask live on, and where they are detached.

Both were created on the GPU, but only embeddings is detached from the GPU to the CPU. So I made the following change:

  • Detached both to the CPU and converted both into Python lists:

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
embeddings = embeddings.detach().cpu().tolist()

for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```

It threw the following error:

File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
TypeError: list indices must be integers or slices, not tuple
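In hindsight this makes sense: .tolist() converts the tensors into nested Python lists, and plain lists only accept integer or slice indices, not the NumPy/PyTorch-style tuple index `[:, a == 1]`. A tiny illustration with hypothetical values:

```python
nested = [[1, 2, 3], [4, 5, 6]]  # the kind of structure .tolist() produces
nested[0][1]  # fine: plain integer indexing returns 2
# nested[:, [True, False, True]]  # TypeError: list indices must be
#                                 # integers or slices, not tuple
```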

To understand the core problem, I wanted to inspect embeddings and attention_mask directly. So:

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)

print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")

embeddings = embeddings.detach().cpu().tolist()

for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```

Details

```
embeddings: tensor([[[[ 2.8550e-01, -9.1305e-01, -3.2084e-01,  ..., -2.6094e-01,
           -1.1316e-01,  6.2532e-02],
          ...,
          [ 1.0556e+00,  9.7592e-01, -1.2148e+00,  ..., -4.7689e-02,
           -1.4709e-02,  2.9145e-01]]]], device='cuda:0') length:torch.Size([2, 9, 120, 512])
attention_mask: tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]]) length:torch.Size([2, 120])
```

From this, embeddings is still on cuda:0 with shape [2, 9, 120, 512], while attention_mask has shape [2, 120] and its second row ends in zeros, i.e. the second sequence is padded. So I made this change, keeping both as CPU tensors:
```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)

print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")

embeddings = embeddings.detach().cpu()

for i, a in enumerate(attention_mask.detach().cpu()):
    embeddings[i] = embeddings[i][:, a == 1]
```

So, finally, with both kept as tensors, the in-place assignment threw a tensor dimension mismatch error, which in hindsight is expected:

File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: The expanded size of the tensor (120) must match the existing size (109) at non-singleton dimension 1.  Target sizes: [9, 120, 512].  Tensor sizes: [9, 109, 512]

This is because embeddings has size [2, 9, 120, 512] while attention_mask has size [2, 120]: the second sequence's mask contains only 109 ones (the rest is padding), so the masked slice of shape [9, 109, 512] can no longer be written back into the 120-wide slot embeddings[i] of the original tensor.

What is the end goal of this snippet? Why does it throw an error? Please help me resolve this.
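For what it's worth, I suspect the intent is to trim each sequence's padding individually, which only works if the batch is first split into a Python list of per-sequence tensors so each entry can shrink to its own length. A minimal sketch of that pattern, with hypothetical stand-in tensors and everything kept on one device:

```python
import torch

# hypothetical stand-ins for a padded batch of per-layer embeddings
embeddings = torch.randn(2, 9, 120, 512)  # [batch, layers, seq, hidden]
attention_mask = torch.ones(2, 120, dtype=torch.long)
attention_mask[1, 109:] = 0  # second sequence padded to 120, real length 109

# unbind into a list so each sequence can be trimmed to a different length
embeddings = list(torch.unbind(embeddings, dim=0))
for i, a in enumerate(attention_mask):
    embeddings[i] = embeddings[i][:, a == 1]

print(embeddings[0].shape)  # torch.Size([9, 120, 512])
print(embeddings[1].shape)  # torch.Size([9, 109, 512])
```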
