Feature request: Simpler loading of custom models #14

Open
dekuenstle opened this issue Feb 1, 2022 · 2 comments
Labels: enhancement (New feature or request)

Comments

@dekuenstle

Using your toolbox with the built-in models is straightforward, but we would like to compare some custom PyTorch models.
It would be great to have a routine that adds such models (i.e. subclasses of nn.Module) to the toolbox registry from the user's own script. If this is already possible, it would be great if you could share an example.

Currently we add the model inside the toolbox's own files, which makes extensions complicated and redundant (the model name has to be repeated in the file path, the function name, and the plotting routine).

Thanks
David

@rgeirhos (Member) commented Feb 1, 2022

If you're looking to add many models to the toolbox, the following approach might work (the example is based on a resnet18 at different epochs, but could easily be adapted to other settings):

for epoch in range(0, 90):
    # the outer f-string fills in {epoch}, so each exec() defines and registers
    # one model: resnet18_epoch_0, resnet18_epoch_1, ...
    exec(f"""@register_model("pytorch")
def resnet18_epoch_{epoch}(model_name="resnet18", *args):
    # placeholder path; assumes the checkpoint was saved with torch.save(model, ...)
    # so that torch.load returns the full model
    model = torch.load("/path/to/model/model_name_epoch_{epoch}")
    return PytorchModel(model, model_name, *args)
""")

This snippet can be inserted in modelvshuman/models/pytorch/model_zoo.py and would be a concise version of registering 90 different models. Is this in line with what you had in mind?
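
For reference, the code that exec generates for a single epoch looks like this (a sketch assuming the register_model decorator and PytorchModel wrapper already imported in model_zoo.py; the checkpoint path is a placeholder):

@register_model("pytorch")
def resnet18_epoch_0(model_name="resnet18", *args):
    model = torch.load("/path/to/model/model_name_epoch_0")
    return PytorchModel(model, model_name, *args)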

@dekuenstle (Author)

Thanks, this workaround can simplify what we are doing.

My initial request was more about a function that is available when importing the toolbox.
If the example were modified to use a custom model, I would expect the following user experience:

import torch
from modelvshuman import Plot, Evaluate, register_model
from modelvshuman import constants as c
from plotting_definition import plotting_definition_template


def run_evaluation():
    # proposed API: register a custom nn.Module instance under a model name
    register_model('pytorch', 'crazy-model', torch.load('path/to/model.pth'))
    models = ["resnet50", "bagnet33", "simclr_resnet50x1", 'crazy-model']
    datasets = c.DEFAULT_DATASETS  # or e.g. ["cue-conflict", "uniform-noise"]
    params = {"batch_size": 64, "print_predictions": True, "num_workers": 20}
    Evaluate()(models, datasets, **params)
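
For what it's worth, a thin helper along these lines could perhaps be layered on top of the decorator from the comment above (purely a hypothetical sketch: register_custom_model is a made-up name, and it assumes the registry keys models by the loader function's __name__):

def register_custom_model(name, module):
    # hypothetical helper: register an already-instantiated nn.Module under `name`
    def loader(model_name=name, *args):
        return PytorchModel(module, model_name, *args)
    loader.__name__ = name  # assumption: the registry uses the function name as the model name
    register_model("pytorch")(loader)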

rgeirhos added the enhancement (New feature or request) label on Feb 2, 2022