diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py index e0631548..f5cb3bf4 100644 --- a/run_pretrained_openfold.py +++ b/run_pretrained_openfold.py @@ -31,13 +31,12 @@ from openfold.config import model_config from openfold.data import templates, feature_pipeline, data_pipeline from openfold.model.model import AlphaFold -from openfold.model.primitives import Attention, GlobalAttention +from openfold.model.torchscript import script_primitives_ from openfold.np import residue_constants, protein import openfold.np.relax.relax as relax from openfold.utils.import_weights import ( import_jax_weights_, ) -from openfold.utils.torchscript_utils import script_submodules_ from openfold.utils.tensor_utils import ( tensor_tree_map, ) @@ -45,10 +44,6 @@ from scripts.utils import add_data_args -def script_primitives_(model): - script_submodules_(model, [Attention, GlobalAttention]) - - def main(args): config = model_config(args.model_name) model = AlphaFold(config)