Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Added support for timm in unet #3717

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions fastai/_modidx.py
Expand Up @@ -1402,6 +1402,7 @@
'fastai.vision.learner.create_vision_model': 'https://docs.fast.ai/vision.learner.html#create_vision_model',
'fastai.vision.learner.cut_model': 'https://docs.fast.ai/vision.learner.html#cut_model',
'fastai.vision.learner.default_split': 'https://docs.fast.ai/vision.learner.html#default_split',
'fastai.vision.learner.get_timm_meta': 'https://docs.fast.ai/vision.learner.html#get_timm_meta',
'fastai.vision.learner.has_pool_type': 'https://docs.fast.ai/vision.learner.html#has_pool_type',
'fastai.vision.learner.model_meta': 'https://docs.fast.ai/vision.learner.html#model_meta',
'fastai.vision.learner.plot_top_losses': 'https://docs.fast.ai/vision.learner.html#plot_top_losses',
Expand Down
62 changes: 44 additions & 18 deletions fastai/vision/learner.py
Expand Up @@ -14,7 +14,7 @@
# %% auto 0
__all__ = ['model_meta', 'has_pool_type', 'cut_model', 'create_body', 'create_head', 'default_split', 'add_head',
'create_vision_model', 'TimmBody', 'create_timm_model', 'vision_learner', 'create_unet_model',
'unet_learner', 'create_cnn_model', 'cnn_learner', 'show_results', 'plot_top_losses']
'get_timm_meta', 'unet_learner', 'create_cnn_model', 'cnn_learner', 'show_results', 'plot_top_losses']

# %% ../nbs/21_vision.learner.ipynb 8
def _is_pool_type(l): return re.search(r'Pool[123]d$', l.__class__.__name__)
Expand Down Expand Up @@ -236,32 +236,58 @@ def vision_learner(dls, arch, normalize=True, n_out=None, pretrained=True,
@delegates(models.unet.DynamicUnet.__init__)
def create_unet_model(arch, n_out, img_size, pretrained=True, cut=None, n_in=3, **kwargs):
    "Create custom unet architecture"
    # `arch` may be either a timm model name (str) or a torchvision-style callable.
    if isinstance(arch, str):
        # NOTE(review): flattening the features_only model into an nn.Sequential
        # discards timm's feature-extraction wrapper — confirm DynamicUnet still
        # hooks the intended intermediate layers for every timm arch.
        timm_model = timm.create_model(arch, pretrained=pretrained, features_only=True,
                                       num_classes=0, in_chans=n_in)
        body = nn.Sequential(*timm_model.children())
        meta = get_timm_meta(arch, cut)
    else:
        meta = model_meta.get(arch, _default_meta)
        body = create_body(arch(pretrained=pretrained), n_in, pretrained, ifnone(cut, meta['cut']))

    return models.unet.DynamicUnet(body, n_out, img_size, **kwargs)

# %% ../nbs/21_vision.learner.ipynb 53
def get_timm_meta(arch: str, cut=None) -> Dict:
    """Return the fastai `model_meta` entry matching timm arch name `arch`.

    Starts from `_default_meta`, overlays the `model_meta` entry whose key
    function is named `arch` (if any), and finally applies an explicit `cut`.

    Bug fixed: the original applied `cut` *before* the `model_meta` lookup, so a
    matched entry's default ``'cut'`` silently clobbered the caller's override.
    `cut` is now applied last, and the scan stops at the first match.
    """
    meta = dict(_default_meta)
    for mm in model_meta:
        if mm.__name__ == arch:
            meta.update(model_meta[mm])
            break
    # Caller-supplied cut always wins over any model_meta default.
    if cut is not None: meta['cut'] = cut
    return meta

# %% ../nbs/21_vision.learner.ipynb 54
@delegates(create_unet_model)
def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,
# learner args
loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,
model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), **kwargs):
model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), n_in=3, cut=None, **kwargs):
"Build a unet learner from `dls` and `arch`"

if config:
warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')
kwargs = {**config, **kwargs}

meta = model_meta.get(arch, _default_meta)
if normalize: _add_norm(dls, meta, pretrained)

n_out = ifnone(n_out, get_c(dls))
assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
img_size = dls.one_batch()[0].shape[-2:]
assert img_size, "image size could not be inferred from data"
model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)

if isinstance(arch, str):
meta = get_timm_meta(arch, cut)
else:
meta = model_meta.get(arch, _default_meta)
if normalize: _add_norm(dls, meta, pretrained)

model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, cut=ifnone(cut, meta["cut"]), n_in=n_in, **kwargs)

splitter=ifnone(splitter, meta['split'])
learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
Expand All @@ -272,26 +298,26 @@ def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=
store_attr('arch,normalize,n_out,pretrained', self=learn, **kwargs)
return learn

# %% ../nbs/21_vision.learner.ipynb 58
# %% ../nbs/21_vision.learner.ipynb 61
def create_cnn_model(*args, **kwargs):
    "Deprecated name for `create_vision_model` -- do not use"
    # Backward-compatibility shim: warn, then forward everything unchanged.
    warn("`create_cnn_model` has been renamed to `create_vision_model` -- please update your code")
    model = create_vision_model(*args, **kwargs)
    return model

# %% ../nbs/21_vision.learner.ipynb 59
# %% ../nbs/21_vision.learner.ipynb 62
def cnn_learner(*args, **kwargs):
    "Deprecated name for `vision_learner` -- do not use"
    # Backward-compatibility shim: warn, then forward everything unchanged.
    warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")
    learn = vision_learner(*args, **kwargs)
    return learn

# %% ../nbs/21_vision.learner.ipynb 61
# %% ../nbs/21_vision.learner.ipynb 64
@typedispatch
def show_results(x:TensorImage, y, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show at most `max_n` of `samples` with their predictions on a grid of plot contexts."
    if ctxs is None:
        # Build a fresh grid only when the caller did not supply contexts.
        ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    # Delegate the actual rendering to the generic (object) dispatch.
    return show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)

# %% ../nbs/21_vision.learner.ipynb 62
# %% ../nbs/21_vision.learner.ipynb 65
@typedispatch
def show_results(x:TensorImage, y:TensorCategory, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
Expand All @@ -301,7 +327,7 @@ def show_results(x:TensorImage, y:TensorCategory, samples, outs, ctxs=None, max_
for b,r,c,_ in zip(samples.itemgot(1),outs.itemgot(0),ctxs,range(max_n))]
return ctxs

# %% ../nbs/21_vision.learner.ipynb 63
# %% ../nbs/21_vision.learner.ipynb 66
@typedispatch
def show_results(x:TensorImage, y:TensorMask|TensorPoint|TensorBBox, samples, outs, ctxs=None, max_n=6,
nrows=None, ncols=1, figsize=None, **kwargs):
Expand All @@ -313,7 +339,7 @@ def show_results(x:TensorImage, y:TensorMask|TensorPoint|TensorBBox, samples, ou
ctxs[1::2] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(o.itemgot(0),ctxs[1::2],range(2*max_n))]
return ctxs

# %% ../nbs/21_vision.learner.ipynb 64
# %% ../nbs/21_vision.learner.ipynb 67
@typedispatch
def show_results(x:TensorImage, y:TensorImage, samples, outs, ctxs=None, max_n=10, figsize=None, **kwargs):
if ctxs is None: ctxs = get_grid(3*min(len(samples), max_n), ncols=3, figsize=figsize, title='Input/Target/Prediction')
Expand All @@ -322,15 +348,15 @@ def show_results(x:TensorImage, y:TensorImage, samples, outs, ctxs=None, max_n=1
ctxs[2::3] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs[2::3],range(max_n))]
return ctxs

# %% ../nbs/21_vision.learner.ipynb 65
# %% ../nbs/21_vision.learner.ipynb 68
@typedispatch
def plot_top_losses(x: TensorImage, y:TensorCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):
    "Plot each sample titled with prediction/actual, its loss, and the max raw score."
    grid = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize, title='Prediction/Actual/Loss/Probability')
    for ctx, sample, out, raw, loss in zip(grid, samples, outs, raws, losses):
        sample[0].show(ctx=ctx, **kwargs)
        ctx.set_title(f'{out[0]}/{sample[1]} / {loss.item():.2f} / {raw.max().item():.2f}')

# %% ../nbs/21_vision.learner.ipynb 66
# %% ../nbs/21_vision.learner.ipynb 69
@typedispatch
def plot_top_losses(x: TensorImage, y:TensorMultiCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):
axs = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize)
Expand All @@ -341,7 +367,7 @@ def plot_top_losses(x: TensorImage, y:TensorMultiCategory, samples, outs, raws,
rows = [b.show(ctx=r, label=l, **kwargs) for b,r in zip(outs.itemgot(i),rows)]
display_df(pd.DataFrame(rows))

# %% ../nbs/21_vision.learner.ipynb 67
# %% ../nbs/21_vision.learner.ipynb 70
@typedispatch
def plot_top_losses(x:TensorImage, y:TensorMask, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):
axes = get_grid(len(samples)*3, nrows=len(samples), ncols=3, figsize=figsize, flatten=False, title="Input | Target | Prediction")
Expand Down
99 changes: 65 additions & 34 deletions nbs/21_vision.learner.ipynb
Expand Up @@ -302,7 +302,7 @@
" (ap): AdaptiveAvgPool2d(output_size=1)\n",
" (mp): AdaptiveMaxPool2d(output_size=1)\n",
" )\n",
" (1): Flatten(full=False)\n",
" (1): fastai.layers.Flatten(full=False)\n",
" (2): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Dropout(p=0.25, inplace=False)\n",
" (4): Linear(in_features=10, out_features=512, bias=False)\n",
Expand Down Expand Up @@ -493,7 +493,7 @@
"#### create_vision_model\n",
"\n",
"> create_vision_model (arch, n_out, pretrained=True, cut=None, n_in=3,\n",
"> init=<functionkaiming_normal_at0x10aac2f70>,\n",
"> init=<functionkaiming_normal_at0x7f92922c3280>,\n",
"> custom_head=None, concat_pool=True, pool=True,\n",
"> lin_ftrs=None, ps=0.5, first_bn=True,\n",
"> bn_final=False, lin_first=False, y_range=None)\n",
Expand Down Expand Up @@ -524,18 +524,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jhoward/mambaforge/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n",
" warnings.warn(\n",
"/Users/jhoward/mambaforge/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"outputs": [],
"source": [
"tst = create_vision_model(models.resnet18, 10, True)\n",
"tst = create_vision_model(models.resnet18, 10, True, n_in=1)"
Expand Down Expand Up @@ -697,16 +686,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jhoward/mambaforge/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"outputs": [],
"source": [
"#|hide\n",
"learn = vision_learner(dls, models.resnet34, loss_func=CrossEntropyLossFlat(), ps=0.25, concat_pool=False)\n",
Expand Down Expand Up @@ -751,9 +731,21 @@
"@delegates(models.unet.DynamicUnet.__init__)\n",
"def create_unet_model(arch, n_out, img_size, pretrained=True, cut=None, n_in=3, **kwargs):\n",
" \"Create custom unet architecture\"\n",
" meta = model_meta.get(arch, _default_meta)\n",
" model = arch(pretrained=pretrained)\n",
" body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut'])) \n",
" if isinstance(arch, str):\n",
" body = timm.create_model(\n",
" arch,\n",
" pretrained=pretrained,\n",
" features_only=True,\n",
" num_classes=0,\n",
" in_chans=n_in,\n",
" )\n",
" body = nn.Sequential(*list(body.children()))\n",
" meta = get_timm_meta(arch, cut)\n",
" else:\n",
" meta = model_meta.get(arch, _default_meta)\n",
" model = arch(pretrained=pretrained)\n",
" body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut'])) \n",
"\n",
" model = models.unet.DynamicUnet(body, n_out, img_size, **kwargs)\n",
" return model"
]
Expand All @@ -775,7 +767,7 @@
"> self_attention=False, y_range=None, last_cross=True,\n",
"> bottle=False,\n",
"> act_cls=<class'torch.nn.modules.activation.ReLU'>,\n",
"> init=<functionkaiming_normal_at0x10aac2f70>,\n",
"> init=<functionkaiming_normal_at0x7f92922c3280>,\n",
"> norm_type=None)\n",
"\n",
"Create custom unet architecture\n",
Expand Down Expand Up @@ -820,6 +812,23 @@
"tst = create_unet_model(models.resnet18, 10, (24,24), True, n_in=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"def get_timm_meta(arch: str, cut = None) -> Dict:\n",
" meta = dict(_default_meta)\n",
" if cut is not None:\n",
" meta.update({\"cut\": cut})\n",
" for mm in list(model_meta.keys()):\n",
" if mm.__name__ == arch:\n",
" meta.update(model_meta[mm])\n",
" return meta"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -831,21 +840,25 @@
"def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,\n",
" # learner args\n",
" loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,\n",
" model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), **kwargs): \n",
" model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), n_in=3, cut=None, **kwargs): \n",
" \"Build a unet learner from `dls` and `arch`\"\n",
" \n",
" if config:\n",
" warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')\n",
" kwargs = {**config, **kwargs}\n",
" \n",
" meta = model_meta.get(arch, _default_meta)\n",
" if normalize: _add_norm(dls, meta, pretrained)\n",
" \n",
" n_out = ifnone(n_out, get_c(dls))\n",
" assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
" img_size = dls.one_batch()[0].shape[-2:]\n",
" assert img_size, \"image size could not be inferred from data\"\n",
" model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)\n",
"\n",
" if isinstance(arch, str):\n",
" meta = get_timm_meta(arch, cut)\n",
" else:\n",
" meta = model_meta.get(arch, _default_meta)\n",
" if normalize: _add_norm(dls, meta, pretrained)\n",
"\n",
" model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, cut=ifnone(cut, meta[\"cut\"]), n_in=n_in, **kwargs)\n",
"\n",
" splitter=ifnone(splitter, meta['split'])\n",
" learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,\n",
Expand Down Expand Up @@ -890,6 +903,24 @@
"learn = unet_learner(dls, models.resnet34, loss_func=CrossEntropyLossFlat(axis=1), y_range=(0,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = unet_learner(dls, \"resnet18\", loss_func=CrossEntropyLossFlat(axis=1), y_range=(0,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = unet_learner(dls, \"convnext_tiny\", loss_func=CrossEntropyLossFlat(axis=1), y_range=(0,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1067,7 +1098,7 @@
"source": [
"#|hide\n",
"from nbdev import nbdev_export\n",
"nbdev_export()"
"nbdev_export(\"21_vision.learner.ipynb\")"
]
},
{
Expand Down