import_jax_weights_ failed on AlphaFold-Multimer 2.3.0 #121

Open
dingquanyu opened this issue May 15, 2023 · 3 comments

Comments

@dingquanyu

Hi,

Thanks for the conversion script. It worked on AlphaFold Multimer 2.2.0, but when I run the same command on AlphaFold 2.3.0 it fails with the KeyError in the traceback below.

I guess that's because AlphaFold changed the layout of outputs in 2.3.0? Do you have plans to update the conversion script later on?

Traceback (most recent call last):
  File "convert_alphafold_to_unifold.py", line 17, in <module>
    import_jax_weights_(model, load_ckpt, version=model_name)
  File "/g/kosinski/geoffrey/unifold_review/unifold_classification/Uni-Fold/scripts/translate_jax_params.py", line 532, in import_jax_weights_
    assign(flat, data)
  File "/g/kosinski/geoffrey/unifold_review/unifold_classification/Uni-Fold/scripts/translate_jax_params.py", line 113, in assign
    weights = torch.as_tensor(orig_weights[k])
  File "/home/dingquanyu/.conda/envs/unifold_3.8/lib/python3.8/site-packages/numpy/lib/npyio.py", line 260, in __getitem__
    raise KeyError("%s is not a file in the archive" % key)
KeyError: 'alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_embedding_iteration/triangle_multiplication_outgoing/layer_norm_input//scale is not a file in the archive'

I have attached the full error report as well (here) and look forward to your advice. Thank you.
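
One way to check whether the key names really differ between the two releases is to list the keys stored in each parameter archive and compare them. The sketch below is only an illustration: it assumes the checkpoint is a NumPy `.npz` file (as the `npyio.py` frame in the traceback suggests), and `find_param_keys` is a hypothetical helper name, not part of Uni-Fold or AlphaFold.

```python
import numpy as np


def find_param_keys(npz_path, substring):
    """Return the sorted keys in an .npz archive that contain `substring`.

    Hypothetical helper for illustration; it reads only the archive's
    key names and does not load any array data.
    """
    with np.load(npz_path) as archive:
        return sorted(key for key in archive.files if substring in key)
```

Calling it with the path to the 2.3.0 parameter file and a substring such as `"triangle_multiplication_outgoing"` would show which names the archive actually stores under that module, which can then be compared against the key named in the KeyError.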

@ZiyaoLi
Member

ZiyaoLi commented May 16, 2023

See if branch #122 fixes this. Please use the model names multimer_af2_v3 / multimer_af2_model45_v3 when running.

Let me know if everything goes smoothly for you. If so, I'll merge this.

@dingquanyu
Author

Hi @ZiyaoLi

I've tested the new scripts but got an error: TypeError: __init__() got an unexpected keyword argument 'swap'

Traceback (most recent call last):
  File "scripts/convert_alphafold_to_unifold.py", line 17, in <module>
    import_jax_weights_(model, load_ckpt, version=model_name)
  File "/g/kosinski/geoffrey/unifold_review/Uni-Fold/scripts/translate_jax_params.py", line 431, in import_jax_weights_
    [TemplatePairBlockParams(b) for b in tps_blocks]
  File "/g/kosinski/geoffrey/unifold_review/Uni-Fold/scripts/translate_jax_params.py", line 431, in <listcomp>
    [TemplatePairBlockParams(b) for b in tps_blocks]
  File "/g/kosinski/geoffrey/unifold_review/Uni-Fold/scripts/translate_jax_params.py", line 339, in <lambda>
    "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
  File "/g/kosinski/geoffrey/unifold_review/Uni-Fold/scripts/translate_jax_params.py", line 244, in <lambda>
    "projection": LinearSwapParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
  File "/g/kosinski/geoffrey/unifold_review/Uni-Fold/scripts/translate_jax_params.py", line 174, in <lambda>
    "weights": LinearWeightSwap(l.weight),
  File "/g/kosinski/geoffrey/unifold_review/Uni-Fold/scripts/translate_jax_params.py", line 150, in <lambda>
    LinearWeightSwap = lambda l: (Param(l, param_type=ParamType.LinearWeight, swap=True))
TypeError: __init__() got an unexpected keyword argument 'swap'
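
This kind of TypeError is what Python raises when a constructor is called with a keyword that the class does not declare. The sketch below reproduces the mechanism with two stand-in dataclasses. It is an assumption-laden illustration, not the Uni-Fold code: `ParamWithoutSwap`, `ParamWithSwap`, and `accepts_swap` are hypothetical names, and the real `Param` class and the meaning of its `swap` flag may differ.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any


class ParamType(Enum):
    # Stand-in enum; only the member used in the traceback is mirrored here.
    LinearWeight = 1
    Other = 2


@dataclass
class ParamWithoutSwap:
    # Hypothetical class with no `swap` field: passing swap=... fails.
    param: Any
    param_type: ParamType = ParamType.Other


@dataclass
class ParamWithSwap:
    # Hypothetical variant that declares `swap` with a default value,
    # so existing callers that omit it keep working.
    param: Any
    param_type: ParamType = ParamType.Other
    swap: bool = False


def accepts_swap(cls):
    """Return True if `cls` can be constructed with a `swap` keyword."""
    try:
        cls([0.0], param_type=ParamType.LinearWeight, swap=True)
    except TypeError:
        return False
    return True
```

Under these assumptions, `accepts_swap(ParamWithoutSwap)` is False and `accepts_swap(ParamWithSwap)` is True, which suggests the call site on line 150 and the class definition were out of sync in the tested checkout.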

@dingquanyu
Author

Hi again,

I've fixed this error and it works now. I've also pushed the fix to a new PR, but I'm not sure whether my solution is correct.
