Skip to content

Commit

Permalink
Added initial timm support for unet learner
Browse files Browse the repository at this point in the history
  • Loading branch information
madhavajay committed Aug 8, 2022
1 parent f197719 commit 543a530
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 52 deletions.
1 change: 1 addition & 0 deletions fastai/_modidx.py
Expand Up @@ -1402,6 +1402,7 @@
'fastai.vision.learner.create_vision_model': 'https://docs.fast.ai/vision.learner.html#create_vision_model',
'fastai.vision.learner.cut_model': 'https://docs.fast.ai/vision.learner.html#cut_model',
'fastai.vision.learner.default_split': 'https://docs.fast.ai/vision.learner.html#default_split',
'fastai.vision.learner.get_timm_meta': 'https://docs.fast.ai/vision.learner.html#get_timm_meta',
'fastai.vision.learner.has_pool_type': 'https://docs.fast.ai/vision.learner.html#has_pool_type',
'fastai.vision.learner.model_meta': 'https://docs.fast.ai/vision.learner.html#model_meta',
'fastai.vision.learner.plot_top_losses': 'https://docs.fast.ai/vision.learner.html#plot_top_losses',
Expand Down
62 changes: 44 additions & 18 deletions fastai/vision/learner.py
Expand Up @@ -14,7 +14,7 @@
# %% auto 0
__all__ = ['model_meta', 'has_pool_type', 'cut_model', 'create_body', 'create_head', 'default_split', 'add_head',
'create_vision_model', 'TimmBody', 'create_timm_model', 'vision_learner', 'create_unet_model',
'unet_learner', 'create_cnn_model', 'cnn_learner', 'show_results', 'plot_top_losses']
'get_timm_meta', 'unet_learner', 'create_cnn_model', 'cnn_learner', 'show_results', 'plot_top_losses']

# %% ../nbs/21_vision.learner.ipynb 8
def _is_pool_type(l): return re.search(r'Pool[123]d$', l.__class__.__name__)
Expand Down Expand Up @@ -236,32 +236,58 @@ def vision_learner(dls, arch, normalize=True, n_out=None, pretrained=True,
@delegates(models.unet.DynamicUnet.__init__)
def create_unet_model(arch, n_out, img_size, pretrained=True, cut=None, n_in=3, **kwargs):
    "Create custom unet architecture"
    # `arch` is either a timm model name (str) or a torchvision-style arch callable.
    # NOTE(review): the diff rendering duplicated the pre-change body here; this is
    # the reconstructed post-commit function with a single if/else.
    if isinstance(arch, str):
        # timm path: build a feature-extractor backbone with the requested
        # number of input channels and no classifier head.
        body = timm.create_model(
            arch,
            pretrained=pretrained,
            features_only=True,
            num_classes=0,
            in_chans=n_in,
        )
        # DynamicUnet expects a plain nn.Sequential it can hook into.
        body = nn.Sequential(*list(body.children()))
        # NOTE(review): `meta` is computed but not used on this branch
        # (DynamicUnet does not take it) — kept for parity with the commit.
        meta = get_timm_meta(arch, cut)
    else:
        meta = model_meta.get(arch, _default_meta)
        model = arch(pretrained=pretrained)
        body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut']))

    model = models.unet.DynamicUnet(body, n_out, img_size, **kwargs)
    return model

# %% ../nbs/21_vision.learner.ipynb 53
def get_timm_meta(arch: str, cut = None) -> Dict:
    "Look up fastai `model_meta` for a timm arch name, falling back to `_default_meta`."
    meta = dict(_default_meta)
    # Match on function name so e.g. 'resnet18' picks up torchvision resnet18's meta.
    # NOTE(review): timm and torchvision cut points may differ for the same name —
    # confirm this mapping is valid for the timm backbones you support.
    for fn in model_meta:
        if fn.__name__ == arch:
            meta.update(model_meta[fn])
            break
    # Apply the explicit `cut` LAST so a caller-supplied value always wins.
    # (The original applied it first, letting the matched meta's 'cut' clobber it.)
    if cut is not None:
        meta['cut'] = cut
    return meta

# %% ../nbs/21_vision.learner.ipynb 54
@delegates(create_unet_model)
def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,
# learner args
loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,
model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), **kwargs):
model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), n_in=3, cut=None, **kwargs):
"Build a unet learner from `dls` and `arch`"

if config:
warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')
kwargs = {**config, **kwargs}

meta = model_meta.get(arch, _default_meta)
if normalize: _add_norm(dls, meta, pretrained)

n_out = ifnone(n_out, get_c(dls))
assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
img_size = dls.one_batch()[0].shape[-2:]
assert img_size, "image size could not be inferred from data"
model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)

if isinstance(arch, str):
meta = get_timm_meta(arch, cut)
else:
meta = model_meta.get(arch, _default_meta)
if normalize: _add_norm(dls, meta, pretrained)

model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, cut=ifnone(cut, meta["cut"]), n_in=n_in, **kwargs)

splitter=ifnone(splitter, meta['split'])
learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
Expand All @@ -272,26 +298,26 @@ def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=
store_attr('arch,normalize,n_out,pretrained', self=learn, **kwargs)
return learn

# %% ../nbs/21_vision.learner.ipynb 58
# %% ../nbs/21_vision.learner.ipynb 61
def create_cnn_model(*args, **kwargs):
    "Deprecated name for `create_vision_model` -- do not use"
    # Emit the rename notice, then forward every argument unchanged.
    warn("`create_cnn_model` has been renamed to `create_vision_model` -- please update your code")
    return create_vision_model(*args, **kwargs)

# %% ../nbs/21_vision.learner.ipynb 59
# %% ../nbs/21_vision.learner.ipynb 62
def cnn_learner(*args, **kwargs):
    "Deprecated name for `vision_learner` -- do not use"
    # Emit the rename notice, then forward every argument unchanged.
    warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")
    return vision_learner(*args, **kwargs)

# %% ../nbs/21_vision.learner.ipynb 61
# %% ../nbs/21_vision.learner.ipynb 64
@typedispatch
def show_results(x:TensorImage, y, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show `samples` alongside predictions, laying out a fresh grid when no `ctxs` are given."
    n = min(len(samples), max_n)
    if ctxs is None:
        ctxs = get_grid(n, nrows=nrows, ncols=ncols, figsize=figsize)
    # Delegate the actual rendering to the generic (object) implementation.
    return show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)

# %% ../nbs/21_vision.learner.ipynb 62
# %% ../nbs/21_vision.learner.ipynb 65
@typedispatch
def show_results(x:TensorImage, y:TensorCategory, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
Expand All @@ -301,7 +327,7 @@ def show_results(x:TensorImage, y:TensorCategory, samples, outs, ctxs=None, max_
for b,r,c,_ in zip(samples.itemgot(1),outs.itemgot(0),ctxs,range(max_n))]
return ctxs

# %% ../nbs/21_vision.learner.ipynb 63
# %% ../nbs/21_vision.learner.ipynb 66
@typedispatch
def show_results(x:TensorImage, y:TensorMask|TensorPoint|TensorBBox, samples, outs, ctxs=None, max_n=6,
nrows=None, ncols=1, figsize=None, **kwargs):
Expand All @@ -313,7 +339,7 @@ def show_results(x:TensorImage, y:TensorMask|TensorPoint|TensorBBox, samples, ou
ctxs[1::2] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(o.itemgot(0),ctxs[1::2],range(2*max_n))]
return ctxs

# %% ../nbs/21_vision.learner.ipynb 64
# %% ../nbs/21_vision.learner.ipynb 67
@typedispatch
def show_results(x:TensorImage, y:TensorImage, samples, outs, ctxs=None, max_n=10, figsize=None, **kwargs):
if ctxs is None: ctxs = get_grid(3*min(len(samples), max_n), ncols=3, figsize=figsize, title='Input/Target/Prediction')
Expand All @@ -322,15 +348,15 @@ def show_results(x:TensorImage, y:TensorImage, samples, outs, ctxs=None, max_n=1
ctxs[2::3] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs[2::3],range(max_n))]
return ctxs

# %% ../nbs/21_vision.learner.ipynb 65
# %% ../nbs/21_vision.learner.ipynb 68
@typedispatch
def plot_top_losses(x: TensorImage, y:TensorCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):
    "Plot the highest-loss samples, titling each axis with prediction/actual/loss/probability."
    grid = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize, title='Prediction/Actual/Loss/Probability')
    for ax, sample, out, raw, loss in zip(grid, samples, outs, raws, losses):
        sample[0].show(ctx=ax, **kwargs)
        # Title parts: predicted class / true class / loss / max raw probability.
        ax.set_title(f'{out[0]}/{sample[1]} / {loss.item():.2f} / {raw.max().item():.2f}')

# %% ../nbs/21_vision.learner.ipynb 66
# %% ../nbs/21_vision.learner.ipynb 69
@typedispatch
def plot_top_losses(x: TensorImage, y:TensorMultiCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):
axs = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize)
Expand All @@ -341,7 +367,7 @@ def plot_top_losses(x: TensorImage, y:TensorMultiCategory, samples, outs, raws,
rows = [b.show(ctx=r, label=l, **kwargs) for b,r in zip(outs.itemgot(i),rows)]
display_df(pd.DataFrame(rows))

# %% ../nbs/21_vision.learner.ipynb 67
# %% ../nbs/21_vision.learner.ipynb 70
@typedispatch
def plot_top_losses(x:TensorImage, y:TensorMask, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):
axes = get_grid(len(samples)*3, nrows=len(samples), ncols=3, figsize=figsize, flatten=False, title="Input | Target | Prediction")
Expand Down
99 changes: 65 additions & 34 deletions nbs/21_vision.learner.ipynb
Expand Up @@ -302,7 +302,7 @@
" (ap): AdaptiveAvgPool2d(output_size=1)\n",
" (mp): AdaptiveMaxPool2d(output_size=1)\n",
" )\n",
" (1): Flatten(full=False)\n",
" (1): fastai.layers.Flatten(full=False)\n",
" (2): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Dropout(p=0.25, inplace=False)\n",
" (4): Linear(in_features=10, out_features=512, bias=False)\n",
Expand Down Expand Up @@ -493,7 +493,7 @@
"#### create_vision_model\n",
"\n",
"> create_vision_model (arch, n_out, pretrained=True, cut=None, n_in=3,\n",
"> init=<functionkaiming_normal_at0x10aac2f70>,\n",
"> init=<functionkaiming_normal_at0x7f92922c3280>,\n",
"> custom_head=None, concat_pool=True, pool=True,\n",
"> lin_ftrs=None, ps=0.5, first_bn=True,\n",
"> bn_final=False, lin_first=False, y_range=None)\n",
Expand Down Expand Up @@ -524,18 +524,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jhoward/mambaforge/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n",
" warnings.warn(\n",
"/Users/jhoward/mambaforge/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"outputs": [],
"source": [
"tst = create_vision_model(models.resnet18, 10, True)\n",
"tst = create_vision_model(models.resnet18, 10, True, n_in=1)"
Expand Down Expand Up @@ -697,16 +686,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jhoward/mambaforge/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"outputs": [],
"source": [
"#|hide\n",
"learn = vision_learner(dls, models.resnet34, loss_func=CrossEntropyLossFlat(), ps=0.25, concat_pool=False)\n",
Expand Down Expand Up @@ -751,9 +731,21 @@
"@delegates(models.unet.DynamicUnet.__init__)\n",
"def create_unet_model(arch, n_out, img_size, pretrained=True, cut=None, n_in=3, **kwargs):\n",
" \"Create custom unet architecture\"\n",
" meta = model_meta.get(arch, _default_meta)\n",
" model = arch(pretrained=pretrained)\n",
" body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut'])) \n",
" if isinstance(arch, str):\n",
" body = timm.create_model(\n",
" arch,\n",
" pretrained=pretrained,\n",
" features_only=True,\n",
" num_classes=0,\n",
" in_chans=n_in,\n",
" )\n",
" body = nn.Sequential(*list(body.children()))\n",
" meta = get_timm_meta(arch, cut)\n",
" else:\n",
" meta = model_meta.get(arch, _default_meta)\n",
" model = arch(pretrained=pretrained)\n",
" body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut'])) \n",
"\n",
" model = models.unet.DynamicUnet(body, n_out, img_size, **kwargs)\n",
" return model"
]
Expand All @@ -775,7 +767,7 @@
"> self_attention=False, y_range=None, last_cross=True,\n",
"> bottle=False,\n",
"> act_cls=<class'torch.nn.modules.activation.ReLU'>,\n",
"> init=<functionkaiming_normal_at0x10aac2f70>,\n",
"> init=<functionkaiming_normal_at0x7f92922c3280>,\n",
"> norm_type=None)\n",
"\n",
"Create custom unet architecture\n",
Expand Down Expand Up @@ -820,6 +812,23 @@
"tst = create_unet_model(models.resnet18, 10, (24,24), True, n_in=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"def get_timm_meta(arch: str, cut = None) -> Dict:\n",
" meta = dict(_default_meta)\n",
" if cut is not None:\n",
" meta.update({\"cut\": cut})\n",
" for mm in list(model_meta.keys()):\n",
" if mm.__name__ == arch:\n",
" meta.update(model_meta[mm])\n",
" return meta"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -831,21 +840,25 @@
"def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,\n",
" # learner args\n",
" loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,\n",
" model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), **kwargs): \n",
" model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), n_in=3, cut=None, **kwargs): \n",
" \"Build a unet learner from `dls` and `arch`\"\n",
" \n",
" if config:\n",
" warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')\n",
" kwargs = {**config, **kwargs}\n",
" \n",
" meta = model_meta.get(arch, _default_meta)\n",
" if normalize: _add_norm(dls, meta, pretrained)\n",
" \n",
" n_out = ifnone(n_out, get_c(dls))\n",
" assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
" img_size = dls.one_batch()[0].shape[-2:]\n",
" assert img_size, \"image size could not be inferred from data\"\n",
" model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)\n",
"\n",
" if isinstance(arch, str):\n",
" meta = get_timm_meta(arch, cut)\n",
" else:\n",
" meta = model_meta.get(arch, _default_meta)\n",
" if normalize: _add_norm(dls, meta, pretrained)\n",
"\n",
" model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, cut=ifnone(cut, meta[\"cut\"]), n_in=n_in, **kwargs)\n",
"\n",
" splitter=ifnone(splitter, meta['split'])\n",
" learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,\n",
Expand Down Expand Up @@ -890,6 +903,24 @@
"learn = unet_learner(dls, models.resnet34, loss_func=CrossEntropyLossFlat(axis=1), y_range=(0,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = unet_learner(dls, \"resnet18\", loss_func=CrossEntropyLossFlat(axis=1), y_range=(0,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = unet_learner(dls, \"convnext_tiny\", loss_func=CrossEntropyLossFlat(axis=1), y_range=(0,1))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1067,7 +1098,7 @@
"source": [
"#|hide\n",
"from nbdev import nbdev_export\n",
"nbdev_export()"
"nbdev_export(\"21_vision.learner.ipynb\")"
]
},
{
Expand Down

0 comments on commit 543a530

Please sign in to comment.