Skip to content

Commit

Permalink
Add learner.predict() usage example
Browse files Browse the repository at this point in the history
  • Loading branch information
bencoman committed Nov 27, 2022
1 parent 6fc2502 commit 8c5074b
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions nbs/13a_learner.ipynb
Expand Up @@ -231,8 +231,7 @@
"source": [
"`file` can be a `Path` object, a string or an opened file object. If a `device` is passed, the model is loaded on it, otherwise it's loaded on the CPU. \n",
"\n",
"If `strict` is `True`, the file must exactly contain weights for every parameter key in `model`, if `strict` is `False`, only the keys that are in the saved model are loaded in `model`.",
"\n",
"If `strict` is `True`, the file must exactly contain weights for every parameter key in `model`, if `strict` is `False`, only the keys that are in the saved model are loaded in `model`.\n",
"You can pass in other kwargs to `torch.load` through `torch_load_kwargs`."
]
},
Expand Down Expand Up @@ -3803,10 +3802,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"It returns a tuple of three elements with, in reverse order,\n",
"- the prediction from the model, potentially passed through the activation of the loss function (if it has one)\n",
"- the decoded prediction, using the potential <code>decodes</code> method from it\n",
"- the fully decoded prediction, using the transforms used to build the `Datasets`/`DataLoaders`"
"It returns a tuple of three elements with, in reverse order:\n",
"- probs[] = the model's predictions (output layer), potentially passed through the activation of the loss function (if it has one); \n",
"- ndx = the decoded prediction, using the potential <code>decodes</code> method from it; \n",
"- label = the fully decoded prediction, using the transforms used to build the `Datasets`/`DataLoaders`."
]
},
{
Expand Down Expand Up @@ -3842,7 +3841,11 @@
"dec = 2*out #decodes from loss function\n",
"full_dec = dec-1 #decodes from _Add1\n",
"test_eq(learn.predict(inp), [full_dec,dec,out])\n",
"test_eq(learn.predict(inp, with_input=True), [inp,full_dec,dec,out])"
"test_eq(learn.predict(inp, with_input=True), [inp,full_dec,dec,out])\n",
"\n",
"# Example usage:\n",
"# label,ndx,probs = learner.predict(inp)\n",
"# print( f\"Predicted: {label} with probability {probs[ndx]}\" )"
]
},
{
Expand Down Expand Up @@ -4219,6 +4222,18 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 8c5074b

Please sign in to comment.