This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

NPID + MoCoV2 weights are the same? #574

Open
ColinConwell opened this issue Oct 25, 2022 · 1 comment

@ColinConwell

The weights served at the download URLs for NPID and MoCoV2 appear to be identical. Perhaps a copying error?

Running the code below demonstrates the equivalence (the Barlow Twins checkpoint is included as a control whose weights differ):

```python
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision.models import resnet50


def get_vissl_model(weights_url):
    weights = load_state_dict_from_url(weights_url, map_location=torch.device('cpu'))

    def replace_module_prefix(state_dict, prefix, replace_with=''):
        # Strip a leading prefix (e.g. VISSL's "_feature_blocks.") from every key
        return {(key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key): val
                for (key, val) in state_dict.items()}

    def convert_model_weights(model):
        # VISSL checkpoints nest the trunk weights; other formats keep them closer to the top level
        if "classy_state_dict" in model.keys():
            model_trunk = model["classy_state_dict"]["base_model"]["model"]["trunk"]
        elif "model_state_dict" in model.keys():
            model_trunk = model["model_state_dict"]
        else:
            model_trunk = model
        return replace_module_prefix(model_trunk, "_feature_blocks.")

    converted_weights = convert_model_weights(weights)

    # Drop head weights (classifier, projection, prototypes) that the ResNet-50 trunk does not have
    excess_weights = ['fc', 'projection', 'prototypes']
    converted_weights = {key: value for (key, value) in converted_weights.items()
                         if not any(x in key for x in excess_weights)}

    # Strip the "module." prefix left by DistributedDataParallel, if present
    if 'module' in next(iter(converted_weights)):
        converted_weights = {key.replace('module.', ''): value
                             for (key, value) in converted_weights.items()}

    model = resnet50()
    model.fc = nn.Identity()  # trunk only: no classification head
    model.load_state_dict(converted_weights)

    return model


### NPID
weights_url = 'https://dl.fbaipublicfiles.com/vissl/model_zoo/npid_1node_200ep_4kneg_npid_8gpu_resnet_23_07_20.9eb36512/model_final_checkpoint_phase199.torch'
model = get_vissl_model(weights_url)
# Print 9 values from the first conv layer (conv1.weight has shape 64x3x7x7)
print(model.conv1.weight[1:10, 1, 1, 1])

### MoCoV2
weights_url = 'https://dl.fbaipublicfiles.com/vissl/model_zoo/moco_v2_1node_lr.03_step_b32_zero_init/model_final_checkpoint_phase199.torch'
model = get_vissl_model(weights_url)
print(model.conv1.weight[1:10, 1, 1, 1])

### BarlowTwins to show the difference
weights_url = 'https://dl.fbaipublicfiles.com/vissl/model_zoo/barlow_twins/barlow_twins_32gpus_4node_imagenet1k_1000ep_resnet50.torch'
model = get_vissl_model(weights_url)
print(model.conv1.weight[1:10, 1, 1, 1])
```
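Printing a nine-value slice of one conv layer only samples a single tensor. A more conclusive check compares every entry of the two state dicts. The sketch below assumes you already have both loaded models (e.g. from `get_vissl_model` above); `compare_state_dicts` is a hypothetical helper, not part of VISSL or PyTorch. The equality test is passed in, so the same function works with `torch.equal` on tensors or `operator.eq` on plain values.

```python
import operator
from typing import Any, Callable, Dict, List, Mapping


def compare_state_dicts(
    a: Mapping[str, Any],
    b: Mapping[str, Any],
    equal: Callable[[Any, Any], bool],
) -> Dict[str, List[str]]:
    """Hypothetical helper: sort keys into identical, different, or present in only one dict.

    `equal` decides whether two values match; pass torch.equal for tensors.
    """
    identical: List[str] = []
    different: List[str] = []
    # Compare only the keys both dicts share, in a stable order
    for key in sorted(a.keys() & b.keys()):
        (identical if equal(a[key], b[key]) else different).append(key)
    return {
        "identical": identical,
        "different": different,
        "only_in_a": sorted(a.keys() - b.keys()),
        "only_in_b": sorted(b.keys() - a.keys()),
    }


# Toy example with plain lists; for real checkpoints something like
# compare_state_dicts(npid_model.state_dict(), moco_model.state_dict(), torch.equal)
report = compare_state_dicts(
    {"conv1.weight": [1, 2], "bn1.bias": [0]},
    {"conv1.weight": [1, 2], "bn1.bias": [1], "fc.weight": [3]},
    operator.eq,
)
print(report)
```

If every shared key lands in `identical` and both `only_in_*` lists are empty, the two trunks hold exactly the same weights, including BatchNorm buffers such as `running_mean`.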
@QuentinDuval
Contributor

Hi @ColinConwell,

Thanks a lot for using VISSL and thanks a lot for raising this issue :)

Let me check this and get back to you!

Thank you,
Quentin

QuentinDuval self-assigned this Jan 6, 2023