
How to use pretrained/finetuned VISSL model for inference? #553

Open
ccedar opened this issue Jun 10, 2022 · 1 comment


ccedar commented Jun 10, 2022

Training

Here is the command I've executed.

python3 run_distributed_engines.py \
    hydra.verbose=False \
    config=benchmark/fulltune/imagenet1k/eval_resnet_8gpu_transfer_in1k_fulltune \
    config.DATA.TRAIN.DATASET_NAMES=[mydata] \
    config.DATA.TRAIN.DATA_PATHS=["path/to/my/data/train"] \
    config.DATA.TRAIN.COPY_TO_LOCAL_DISK=False \
    config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=32 \
    config.DATA.TEST.DATASET_NAMES=[mydata] \
    config.DATA.TEST.DATA_PATHS=["path/to/my/data/test"] \
    config.DATA.TEST.BATCHSIZE_PER_REPLICA=32 \
    config.DISTRIBUTED.NUM_NODES=1 \
    config.DISTRIBUTED.NUM_PROC_PER_NODE=1 \
    config.CHECKPOINT.DIR="./checkpoints_test" \
    config.MODEL.WEIGHTS_INIT.PARAMS_FILE="resnet50-19c8e357.pth" \
    config.MODEL.WEIGHTS_INIT.APPEND_PREFIX="trunk._feature_blocks." \
    config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME=""

I edited the head params' dimensions as below, because I want to classify the images into 5 classes.

MODEL:
    TRUNK:
      NAME: resnet
      TRUNK_PARAMS:
        RESNETS:
          DEPTH: 50
    HEAD:
      PARAMS: [
        ["mlp", {"dims": [2048, 5]}],
      ]
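
(For context, my understanding is that an "mlp" head with dims [2048, 5] amounts to a single linear layer mapping the 2048-dimensional pooled ResNet-50 features to 5 class logits, roughly like this PyTorch sketch:)

import torch.nn as nn

# Rough equivalent of the head config above (my assumption, not the exact VISSL
# implementation): one linear layer from the 2048-d pooled trunk features to 5 logits.
head = nn.Linear(2048, 5)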

Inference

import os
from vissl.models import build_model
from classy_vision.generic.util import load_checkpoint
from vissl.utils.checkpoint import init_model_from_consolidated_weights
from PIL import Image
import torchvision.transforms as transforms
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict

cfg = [
  'config=benchmark/fulltune/imagenet1k/eval_resnet_8gpu_transfer_in1k_fulltune.yaml',
  'config.MODEL.WEIGHTS_INIT.PARAMS_FILE=checkpoints_0609/model_final_checkpoint_phase599.torch', # Specify path for the model weights.
  'config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True', # Turn on model evaluation mode.
  'config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=False', # Do not freeze only the trunk.
  'config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=True', # Freeze both the trunk and the head.
  'config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=True', # Evaluate the trunk and head together.
  'config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=True', # Extract the trunk features, as opposed to the HEAD.
  'config.MODEL.FEATURE_EVAL_SETTINGS.SHOULD_FLATTEN_FEATS=False', # Do not flatten features.
  'config.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP=[["res5avg", ["Identity", []]]]' # Extract only the res5avg features.
]
# Compose the hydra configuration.
cfg = compose_hydra_configuration(cfg)
# Convert to AttrDict. This method will also infer certain config options
# and validate the config is valid.
_, cfg = convert_to_attrdict(cfg)

model = build_model(cfg.MODEL, cfg.OPTIMIZER)

weights = load_checkpoint(checkpoint_path=cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE)

init_model_from_consolidated_weights(
    config=cfg,
    model=model,
    state_dict=weights,
    state_dict_key_name="classy_state_dict",
    skip_layers=[],  # Use this if you do not want to load all layers
)
print("Weights have loaded")

pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img_dir = "/path/to/my/data/test"

for img_name in os.listdir(img_dir):
    img_fname = os.path.join(img_dir, img_name)
    image = Image.open(img_fname).convert("RGB")
    x = pipeline(image)
    features = model(x.unsqueeze(0))
    _, pred = features[0].float().topk(1, largest=True, sorted=True)
    print(img_fname, features[0][-1], pred[0])

When I run the code above, I get the results below. I expected the tensor to have a length of 5, but the tensor I get has a length of 2048.

/path/to/my/data/test/0.jpg tensor([0.2968, 0.4762, 0.3878,  ..., 0.3190, 0.3820, 0.3187],  grad_fn=<SelectBackward>) tensor([1501])
/path/to/my/data/test/1.jpg tensor([0.3165, 0.4554, 0.4152,  ..., 0.3237, 0.3844, 0.3104],  grad_fn=<SelectBackward>) tensor([1501])
/path/to/my/data/test/3.jpg tensor([0.3076, 0.5146, 0.4472,  ..., 0.3349, 0.3641, 0.3613],  grad_fn=<SelectBackward>) tensor([617])

Could you please point me in the right direction? It would be really helpful! Thank you in advance :)

@QuentinDuval (Contributor) commented

Hi @ccedar,

First of all, thank you for using VISSL :)
And sorry for the late answer, I took a long PTO break of almost a month.

So let me first make sure I understood your test case correctly:

  • You first ran a command to fine-tune a pre-trained ResNet50 and adapt it to 5 classes
  • You then want to extract the predictions of the fine-tuned ResNet50?

If so, the second script has to be corrected to drop this line (it asks for the features just before the projection head):
'config.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP=[["res5avg", ["Identity", []]]]' # Extract only the res5avg features.

If you remove this line, it should extract the predictions of the head instead (with dimension 5 for 5 classes, instead of 2048).
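
For reference, a minimal sketch of what the corrected overrides list in your inference script would look like, assuming everything else stays as in your snippet (it is simply your original list with the LINEAR_EVAL_FEAT_POOL_OPS_MAP entry dropped):

cfg = [
  'config=benchmark/fulltune/imagenet1k/eval_resnet_8gpu_transfer_in1k_fulltune.yaml',
  'config.MODEL.WEIGHTS_INIT.PARAMS_FILE=checkpoints_0609/model_final_checkpoint_phase599.torch',
  'config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True',
  'config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=False',
  'config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=True',
  'config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=True',
  'config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=True',
  'config.MODEL.FEATURE_EVAL_SETTINGS.SHOULD_FLATTEN_FEATS=False',
  # The LINEAR_EVAL_FEAT_POOL_OPS_MAP override is intentionally omitted here.
]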

Thank you,
Quentin
