"cut" parameter not passed to vision_learner #4027

Open
adamamer20 opened this issue Apr 25, 2024 · 0 comments · May be fixed by #4028

Confirmed that the latest versions of fastai, fastcore, and nbdev are installed before reporting: YES

Describe the bug
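When creating a Learner with `vision_learner` and an explicit `cut` argument, the value is ignored. The cause is that `vision_learner` leaves `cut` out of the `model_args` dictionary it passes to `create_vision_model`, so `create_vision_model` receives `cut=None` and falls back to `meta['cut']`. For this model that is also `None`, so `create_body` tries to find the cut automatically.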
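A minimal sketch of the forwarding problem, using hypothetical pure-Python stand-ins (`create_vision_model_stub`, `vision_learner_buggy`, `vision_learner_fixed`) instead of the real fastai functions. It only assumes what the traceback shows: `create_vision_model` resolves the cut with `ifnone(cut, meta['cut'])`, and `meta['cut']` is `None` for this architecture.

```python
# Hypothetical stand-ins for fastai's vision_learner / create_vision_model.
# They model only how the `cut` argument is forwarded, nothing else.

DEFAULT_META = {"cut": None}  # assumed metadata for an arch not in fastai's model_meta


def ifnone(a, b):
    "Return `b` if `a` is None, else `a` (mirrors fastcore's ifnone)."
    return b if a is None else a


def create_vision_model_stub(arch, n_out, cut=None, **kwargs):
    "Return the cut that would be handed to create_body."
    meta = DEFAULT_META
    return ifnone(cut, meta["cut"])


def vision_learner_buggy(arch, n_out, cut=None, ps=0.5, **kwargs):
    # `cut` is accepted here but never added to model_args, so it is dropped
    model_args = dict(ps=ps, **kwargs)
    return create_vision_model_stub(arch, n_out, **model_args)


def vision_learner_fixed(arch, n_out, cut=None, ps=0.5, **kwargs):
    # Forwarding `cut` lets the caller's value reach create_vision_model
    model_args = dict(cut=cut, ps=ps, **kwargs)
    return create_vision_model_stub(arch, n_out, **model_args)


print(vision_learner_buggy("vit", 101, cut=-2))  # None -> automatic cut search
print(vision_learner_fixed("vit", 101, cut=-2))  # -2
```

This sketch only models the `create_vision_model` path that appears in the traceback; the actual patch in #4028 may differ.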

To Reproduce
Steps to reproduce the behavior:

  1. https://gist.github.com/adamamer20/122c54777165a567f9275677a849a534

Expected behavior
The `cut=-2` argument should be used to split the model, so initializing the Learner should not raise a `StopIteration` error caused by the automatic cut search failing to find a pooling layer.
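The automatic cut in `create_body` takes the index of the last child that `has_pool_type` accepts, and `next()` raises `StopIteration` when no child qualifies. The sketch below reproduces that search on hypothetical lists of layer class names instead of real `nn.Module` children. Its `has_pool_type` is a simplified stand-in that matches names ending in `Pool1d`, `Pool2d` or `Pool3d`; unlike fastai's version, it does not search nested modules.

```python
import re

# Hypothetical models, given as the class names of their top-level children.
resnet_like = ["Conv2d", "BatchNorm2d", "ReLU", "MaxPool2d", "Sequential", "AdaptiveAvgPool2d", "Linear"]
vit_like = ["PatchEmbed", "Dropout", "Sequential", "LayerNorm", "Linear"]


def has_pool_type(layer_name):
    "Simplified check: is this class name a 1d/2d/3d pooling layer?"
    return re.search(r"Pool[123]d$", layer_name) is not None


def auto_cut(children):
    "Index of the last pooling child, scanning from the end like create_body."
    ll = list(enumerate(children))
    # next() raises StopIteration if the generator yields nothing
    return next(i for i, o in reversed(ll) if has_pool_type(o))


print(auto_cut(resnet_like))  # 5 (AdaptiveAvgPool2d)
try:
    auto_cut(vit_like)
except StopIteration:
    print("StopIteration: no pooling child, so an explicit cut is required")
```

A model without any pooling child, like the ViT in this report, can only be split if a user-supplied `cut` actually reaches `create_body`.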

Error with full stack trace

```
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In[4], line 5
      1 timm_model = "hf_hub:anonauthors/food101-timm-vit_base_patch16_224.orig_in21k_ft_in1k"
      3 model = partial(timm.create_model, timm_model, pretrained=True)
----> 5 learn = vision_learner(dls, model, metrics=error_rate, cut=-2)

File c:\ProgramData\miniforge3\envs\finetuning\Lib\site-packages\fastai\vision\learner.py:236, in vision_learner(dls, arch, normalize, n_out, pretrained, weights, loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms, cut, init, custom_head, concat_pool, pool, lin_ftrs, ps, first_bn, bn_final, lin_first, y_range, **kwargs)
    234 else:
    235     if normalize: _add_norm(dls, meta, pretrained, n_in)
--> 236     model = create_vision_model(arch, n_out, pretrained=pretrained, weights=weights, **model_args)
    238 splitter = ifnone(splitter, meta['split'])
    239 learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
    240                metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn, moms=moms)

File c:\ProgramData\miniforge3\envs\finetuning\Lib\site-packages\fastai\vision\learner.py:173, in create_vision_model(arch, n_out, pretrained, weights, cut, n_in, init, custom_head, concat_pool, pool, lin_ftrs, ps, first_bn, bn_final, lin_first, y_range)
    171 else:
    172     model = arch(pretrained=pretrained)
--> 173 body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut']))
    174 nf = num_features_model(nn.Sequential(*body.children())) if custom_head is None else None
    175 return add_head(body, nf, n_out, init=init, head=custom_head, concat_pool=concat_pool, pool=pool,
    176                 lin_ftrs=lin_ftrs, ps=ps, first_bn=first_bn, bn_final=bn_final, lin_first=lin_first, y_range=y_range)

File c:\ProgramData\miniforge3\envs\finetuning\Lib\site-packages\fastai\vision\learner.py:84, in create_body(model, n_in, pretrained, cut)
     82 if cut is None:
     83     ll = list(enumerate(model.children()))
---> 84     cut = next(i for i,o in reversed(ll) if has_pool_type(o))
     85 return cut_model(model, cut)

StopIteration: 
```


adamamer20 linked a pull request on Apr 25, 2024 that will close this issue.