
The pre-trained GearNet-Edge model for Fold Classification #41

Open
arantir123 opened this issue Jul 25, 2023 · 5 comments

Comments

@arantir123

arantir123 commented Jul 25, 2023

Thank you for your amazing work! I found that for the Fold Classification task, the GearNet-Edge model was implemented with the GearNetIEConv script rather than the GearNet script, and the two differ in some details (e.g., an extra input embedding and IEConv layers). Given this, could you provide the pre-trained GearNet-Edge model based on multiview contrastive learning and the GearNetIEConv script for Fold Classification (rather than the one based on the GearNet script for the EC task)? Thank you.

@Oxer11
Collaborator

Oxer11 commented Jul 31, 2023

Hi, the config file for GearNet-Edge-IEConv on Fold is config/Fold3D/gearnet_edge_ieconv.yaml. The pre-trained checkpoints of GearNet-Edge can be found at https://zenodo.org/record/7723075.

@arantir123
Author

Thank you. It seems that fold_mc_gearnet_edge_ieconv.pth includes both the encoder and decoder parameters after fine-tuning. I would like to run some experiments on my own: that is, obtain the pre-trained GearNet-Edge-IEConv encoder before fine-tuning, along with the fine-tuning configuration script and the corresponding run command (e.g., how many GPUs and what batch size were actually used in fine-tuning), and then do the fine-tuning myself. Would it be convenient for you to provide these? Thank you very much.
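Separating a fine-tuned checkpoint into encoder and task-head parameters is mostly key-prefix filtering on the state dict. A minimal sketch follows; the "model." prefix is an assumption about the checkpoint layout, not something confirmed for fold_mc_gearnet_edge_ieconv.pth, so inspect the real keys first:

```python
def split_checkpoint(state_dict, encoder_prefix="model."):
    """Separate encoder weights from task-head weights by key prefix.

    The prefix is a guess at the checkpoint layout; run
    print(state_dict.keys()) on the real .pth file to find the actual one.
    """
    encoder, head = {}, {}
    for key, value in state_dict.items():
        if key.startswith(encoder_prefix):
            # Strip the prefix so the encoder can be loaded standalone.
            encoder[key[len(encoder_prefix):]] = value
        else:
            head[key] = value
    return encoder, head

# Toy state dict standing in for torch.load(path)["model"]:
fake = {"model.layers.0.weight": "w0", "mlp.weight": "h0"}
enc, head = split_checkpoint(fake)
# enc == {"layers.0.weight": "w0"}, head == {"mlp.weight": "h0"}
```

With a real file you would pass the loaded state dict instead of the toy one and save the encoder half for later pre-training-style loading.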

@Oxer11
Collaborator

Oxer11 commented Aug 3, 2023

I see. The original pre-trained checkpoints were deleted by my cluster. I've pre-trained a new GearNet-Edge-IEConv recently. You can download the checkpoint from this link and have a try. Please ping me if there is any problem with the checkpoint.

For fine-tuning, just use the following command:

python script/downstream.py -c config/downstream/Fold3D/gearnet_edge_ieconv.yaml --gpus [0] --ckpt <path_to_your_model>
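Note that `--gpus [0]` is a Python-style list literal passed on the command line. A flag like that can be parsed as sketched below; whether script/downstream.py does exactly this is an assumption, but it shows why the bracket syntax works:

```python
import argparse
import ast

# Sketch: parse "--gpus [0]" as a Python list literal.
# This is illustrative, not the actual downstream.py argument parser.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=ast.literal_eval, default=None)
parser.add_argument("--ckpt", default=None)

args = parser.parse_args(["--gpus", "[0]", "--ckpt", "model.pth"])
# args.gpus == [0], a list of device indices
```

Multi-GPU runs would then be `--gpus [0,1,2,3]`, which parses to a four-element list.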

@arantir123
Author


Thank you very much. I will have a try.

@arantir123
Author


Hi, it seems that the model in the above link does not match the size of the model at the official https://zenodo.org/record/7723075 (the hidden dimensions of each layer are different). I guess the model at https://zenodo.org/record/7723075 is based on the following newer implementation of GearNet-Edge-IEConv (with the extra input embedding, etc.):
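A quick way to pin down this kind of mismatch is to compare parameter shapes between the two checkpoints. A small sketch on plain name-to-shape mappings (with real checkpoints you would build these as `{k: tuple(v.shape) for k, v in torch.load(p).items()}`; the toy names and sizes below are illustrative only):

```python
def shape_mismatches(sd_a, sd_b):
    """Report parameters whose shapes differ between two state dicts.

    Works on any mapping of name -> shape tuple; keys present in only
    one checkpoint are ignored here but are also worth printing.
    """
    report = []
    for name in sorted(set(sd_a) & set(sd_b)):
        if sd_a[name] != sd_b[name]:
            report.append((name, sd_a[name], sd_b[name]))
    return report

# Toy shapes illustrating a hidden-dimension mismatch in the first layer:
a = {"layers.0.weight": (512, 21), "linear.weight": (512, 21)}
b = {"layers.0.weight": (512, 512), "linear.weight": (512, 21)}
# shape_mismatches(a, b) → [("layers.0.weight", (512, 21), (512, 512))]
```

If the mismatches cluster around the first layer and an embedding module, that points at the extra input embedding in the newer implementation rather than at a corrupted file.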

# Imports needed for this snippet; `layer` is the repo's own gearnet/layer.py,
# as opposed to torchdrug.layers (imported here as `layers`).
from collections.abc import Sequence

import torch
from torch import nn
from torch.nn import functional as F

from torchdrug import core, layers
from torchdrug.core import Registry as R

from gearnet import layer


@R.register("models.GearNetIEConv")
class GearNetIEConv(nn.Module, core.Configurable):

    def __init__(self, input_dim, embedding_dim, hidden_dims, num_relation, edge_input_dim=None,
                 batch_norm=False, activation="relu", concat_hidden=False, short_cut=True,
                 readout="sum", dropout=0, num_angle_bin=None, layer_norm=False, use_ieconv=False):
        super(GearNetIEConv, self).__init__()

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [embedding_dim if embedding_dim > 0 else input_dim] + list(hidden_dims)
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.concat_hidden = concat_hidden
        self.short_cut = short_cut
        self.num_angle_bin = num_angle_bin
        self.layer_norm = layer_norm
        self.use_ieconv = use_ieconv

        if embedding_dim > 0:
            self.linear = nn.Linear(input_dim, embedding_dim)
            self.embedding_batch_norm = nn.BatchNorm1d(embedding_dim)

        self.layers = nn.ModuleList()
        self.ieconvs = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            # note that these layers are from gearnet.layer instead of torchdrug.layers
            self.layers.append(layer.GeometricRelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation,
                                                                  None, batch_norm, activation))
            if use_ieconv:
                self.ieconvs.append(layer.IEConvLayer(self.dims[i], self.dims[i] // 4,
                                    self.dims[i + 1], edge_input_dim=14, kernel_hidden_dim=32))
        if num_angle_bin:
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layer.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if layer_norm:
            self.layer_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.layer_norms.append(nn.LayerNorm(self.dims[i + 1]))

        self.dropout = nn.Dropout(dropout)

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def get_ieconv_edge_feature(self, graph):
        # build a local reference frame per residue from consecutive positions
        u = torch.ones_like(graph.node_position)
        u[1:] = graph.node_position[1:] - graph.node_position[:-1]
        u = F.normalize(u, dim=-1)
        b = torch.ones_like(graph.node_position)
        b[:-1] = u[:-1] - u[1:]
        b = F.normalize(b, dim=-1)
        n = torch.ones_like(graph.node_position)
        n[:-1] = torch.cross(u[:-1], u[1:], dim=-1)
        n = F.normalize(n, dim=-1)

        local_frame = torch.stack([b, n, torch.cross(b, n, dim=-1)], dim=-1)

        node_in, node_out = graph.edge_list.t()[:2]
        t = graph.node_position[node_out] - graph.node_position[node_in]
        t = torch.einsum('ijk, ij->ik', local_frame[node_in], t)
        r = torch.sum(local_frame[node_in] * local_frame[node_out], dim=1)
        delta = torch.abs(graph.atom2residue[node_in] - graph.atom2residue[node_out]).float() / 6
        delta = delta.unsqueeze(-1)

        return torch.cat([
            t, r, delta,
            1 - 2 * t.abs(), 1 - 2 * r.abs(), 1 - 2 * delta.abs()
        ], dim=-1)

    def forward(self, graph, input, all_loss=None, metric=None):
        hiddens = []
        layer_input = input
        if self.embedding_dim > 0:
            layer_input = self.linear(layer_input)
            layer_input = self.embedding_batch_norm(layer_input)
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_hidden = line_graph.node_feature.float()
        else:
            edge_hidden = None
        ieconv_edge_feature = self.get_ieconv_edge_feature(graph)

        for i in range(len(self.layers)):
            # edge message passing
            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_hidden)
            hidden = self.layers[i](graph, layer_input, edge_hidden)
            # ieconv layer
            if self.use_ieconv:
                hidden = hidden + self.ieconvs[i](graph, layer_input, ieconv_edge_feature)
            hidden = self.dropout(hidden)

            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input

            if self.layer_norm:
                hidden = self.layer_norms[i](hidden)
            hiddens.append(hidden)
            layer_input = hidden

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }
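For reference, the dimension bookkeeping in the constructor above can be checked in isolation. This is a small pure-Python mirror of those lines (not part of the repo, and the example values are illustrative rather than taken from the actual yaml config):

```python
from collections.abc import Sequence

def gearnet_dims(input_dim, embedding_dim, hidden_dims, edge_input_dim):
    """Mirror the dims/edge_dims bookkeeping from GearNetIEConv.__init__."""
    if not isinstance(hidden_dims, Sequence):
        hidden_dims = [hidden_dims]
    # Node feature dims per layer: embedding size (if used) then hidden sizes.
    dims = [embedding_dim if embedding_dim > 0 else input_dim] + list(hidden_dims)
    # Edge message-passing dims: edge features in, then node dims shifted by one.
    edge_dims = [edge_input_dim] + dims[:-1]
    return dims, edge_dims

# Illustrative values in the spirit of a GearNet-Edge-IEConv config:
dims, edge_dims = gearnet_dims(input_dim=21, embedding_dim=512,
                               hidden_dims=[512] * 6, edge_input_dim=59)
# dims == [512, 512, 512, 512, 512, 512, 512]
# edge_dims == [59, 512, 512, 512, 512, 512, 512]
```

Comparing these lists against the tensor shapes in a checkpoint makes it easy to see whether a given .pth file was produced with or without the input embedding.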
