Skip to content

Commit

Permalink
Merge pull request #3964 from turbotimon/master
Browse files Browse the repository at this point in the history
  • Loading branch information
jph00 committed Feb 29, 2024
2 parents 8e268e3 + f206928 commit 80d881e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
22 changes: 17 additions & 5 deletions fastai/learner.py
Expand Up @@ -594,13 +594,25 @@ def _valid_mets(self):
if getattr(self, 'cancel_valid', False): return L()
return (L(self.loss) + self.metrics if self.valid_metrics else L())

def plot_loss(self, skip_start=5, with_valid=True):
plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')
def plot_loss(self, skip_start=5, with_valid=True, log=False, show_epochs=False, ax=None):
    # Plot the recorded training loss per batch, and optionally the per-epoch
    # validation loss, on a matplotlib Axes which is returned to the caller.
    #   skip_start:  number of initial batches to omit (early losses are noisy)
    #   with_valid:  also plot `valid_loss` taken from `self.values`
    #   log:         use log-log axes instead of linear ones
    #   show_epochs: draw a dotted vertical line at each recorded epoch boundary (`self.iters`)
    #   ax:          Axes to draw on; defaults to the current axes (`plt.gca()`)
    # Compare with `None` rather than truthiness so an explicitly passed Axes
    # is always honoured, whatever its `__bool__` happens to return.
    if ax is None: ax = plt.gca()
    xs = list(range(skip_start, len(self.losses)))
    plot = ax.loglog if log else ax.plot
    plot(xs, self.losses[skip_start:], label='train')
    if show_epochs:
        for x in self.iters: ax.axvline(x, color='grey', ls=':')
    ax.set_ylabel('loss')
    ax.set_xlabel('steps')
    ax.set_title('learning curve')
    if with_valid:
        # Skip validation points that fall inside the skipped region.
        idx = (np.array(self.iters)<skip_start).sum()
        valid_col = self.metric_names.index('valid_loss') - 1
        ax.plot(self.iters[idx:], L(self.values[idx:]).itemgot(valid_col), label='valid')
        ax.legend()
    return ax

# %% ../nbs/13a_learner.ipynb 136
add_docs(Recorder,
Expand All @@ -610,7 +622,7 @@ def plot_loss(self, skip_start=5, with_valid=True):
after_validate = "Log loss and metric values on the validation set",
after_cancel_train = "Ignore training metrics for this epoch",
after_cancel_validate = "Ignore validation metrics for this epoch",
plot_loss = "Plot the losses from `skip_start` and onward")
plot_loss = "Plot the losses from `skip_start` and onward. Optionally `log=True` for a logarithmic axis, `show_epochs=True` to indicate epochs, and a matplotlib axis `ax` to plot on.")

if Recorder not in defaults.callbacks: defaults.callbacks.append(Recorder)

Expand Down
22 changes: 17 additions & 5 deletions nbs/13a_learner.ipynb
Expand Up @@ -2924,13 +2924,25 @@
" if getattr(self, 'cancel_valid', False): return L()\n",
" return (L(self.loss) + self.metrics if self.valid_metrics else L())\n",
"\n",
" def plot_loss(self, skip_start=5, with_valid=True):\n",
" plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')\n",
" def plot_loss(self, skip_start=5, with_valid=True, log=False, show_epochs=False, ax=None):\n",
" if not ax:\n",
" ax=plt.gca()\n",
" if log:\n",
" ax.loglog(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')\n",
" else:\n",
" ax.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')\n",
" if show_epochs:\n",
" for x in self.iters:\n",
" ax.axvline(x, color='grey', ls=':')\n",
" ax.set_ylabel('loss')\n",
" ax.set_xlabel('steps')\n",
" ax.set_title('learning curve')\n",
" if with_valid:\n",
" idx = (np.array(self.iters)<skip_start).sum()\n",
" valid_col = self.metric_names.index('valid_loss') - 1 \n",
" plt.plot(self.iters[idx:], L(self.values[idx:]).itemgot(valid_col), label='valid')\n",
" plt.legend()"
" ax.plot(self.iters[idx:], L(self.values[idx:]).itemgot(valid_col), label='valid')\n",
" ax.legend()\n",
" return ax"
]
},
{
Expand All @@ -2947,7 +2959,7 @@
" after_validate = \"Log loss and metric values on the validation set\",\n",
" after_cancel_train = \"Ignore training metrics for this epoch\",\n",
" after_cancel_validate = \"Ignore validation metrics for this epoch\",\n",
" plot_loss = \"Plot the losses from `skip_start` and onward\")\n",
" plot_loss = \"Plot the losses from `skip_start` and onward. Optionally `log=True` for logarithmic axis, `show_epochs=True` for indicate epochs and a matplotlib axis `ax` to plot on.\")\n",
"\n",
"if Recorder not in defaults.callbacks: defaults.callbacks.append(Recorder)"
]
Expand Down

0 comments on commit 80d881e

Please sign in to comment.