
Commit c698370

Merge pull request #778 from alan-turing-institute/dev
For a 0.16.2 release

2 parents 35219a9 + dbbd0ed, commit c698370

8 files changed: +75 −57 lines

Project.toml
Lines changed: 15 additions & 15 deletions

@@ -1,7 +1,7 @@
 name = "MLJ"
 uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
 authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
-version = "0.16.1"
+version = "0.16.2"

 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -24,20 +24,20 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

 [compat]
-CategoricalArrays = "^0.8,^0.9"
-ComputationalResources = "^0.3"
-Distributions = "^0.21,^0.22,^0.23, 0.24"
-MLJBase = "^0.18"
-MLJIteration = "^0.2"
-MLJModels = "^0.14"
-MLJOpenML = "^1"
-MLJScientificTypes = "^0.4.1"
-MLJSerialization = "^1.1"
-MLJTuning = "^0.6"
-ProgressMeter = "^1.1"
-StatsBase = "^0.32,^0.33"
-Tables = "^0.2,^1.0"
-julia = "^1.1"
+CategoricalArrays = "0.8,0.9"
+ComputationalResources = "0.3"
+Distributions = "0.21,0.22,0.23, 0.24"
+MLJBase = "0.18"
+MLJIteration = "0.3"
+MLJModels = "0.14"
+MLJOpenML = "1"
+MLJScientificTypes = "0.4.1"
+MLJSerialization = "1.1"
+MLJTuning = "0.6"
+ProgressMeter = "1.1"
+StatsBase = "0.32,0.33"
+Tables = "0.2,1.0"
+julia = "1.1"

 [extras]
 NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
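Aside from the `MLJIteration` bump (`0.2` to `0.3`), the dropped carets above are cosmetic: a bare version in a Julia `[compat]` entry already gets caret (semver-compatible) interpretation. A sketch of the equivalence check; note that `Pkg.Types.semver_spec` is an internal helper, not part of Pkg's stable public API:

```julia
using Pkg

# A bare compat entry is interpreted with caret semantics, so
# "0.18" and "^0.18" should parse to the same version range.
spec_bare  = Pkg.Types.semver_spec("0.18")
spec_caret = Pkg.Types.semver_spec("^0.18")
@show spec_bare == spec_caret
```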

docs/Project.toml
Lines changed: 7 additions & 1 deletion

@@ -30,5 +30,11 @@ ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
 TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

 [compat]
-Documenter = "^0.26"
+Documenter = "0.26"
+MLJBase = "0.18"
+MLJIteration = "0.3"
+MLJModels = "0.14.4"
+MLJScientificTypes = "0.4.6"
+MLJTuning = "0.6.5"
+ScientificTypes = "1.1.1"
 julia = "1"

docs/src/controlling_iterative_models.md
Lines changed: 41 additions & 36 deletions

@@ -37,11 +37,11 @@ iterations from the controlled training phase:
 ```@example gree
 using MLJ

-X, y = make_moons(1000, rng=123)
+X, y = make_moons(100, rng=123, noise=0.5)
 EvoTreeClassifier = @load EvoTreeClassifier verbosity=0

 iterated_model = IteratedModel(model=EvoTreeClassifier(rng=123, η=0.005),
-                               resampling=Holdout(rng=123),
+                               resampling=Holdout(),
                                measures=log_loss,
                                controls=[Step(5),
                                          Patience(2),
@@ -92,26 +92,26 @@ the `IteratedModel` wrapper, but trained in each iteration on a subset
 of the data, according to the value of the `resampling`
 hyper-parameter of the wrapper.

-control | description | can trigger a stop
---------|-------------|--------------------
-[`Step`](@ref IterationControl.Step)`(n=1)` | Train model for `n` more iterations | no
-[`TimeLimit`](@ref EarlyStopping.TimeLimit)`(t=0.5)` | Stop after `t` hours | yes
-[`NumberLimit`](@ref EarlyStopping.NumberLimit)`(n=100)` | Stop after `n` applications of the control | yes
-[`NumberSinceBest`](@ref EarlyStopping.NumberSinceBest)`(n=6)` | Stop when best loss occurred `n` control applications ago | yes
-[`NotANumber`](@ref EarlyStopping.NotANumber)`()` | Stop when `NaN` encountered | yes
-[`Threshold`](@ref EarlyStopping.Threshold)`(value=0.0)` | Stop when `loss < value` | yes
-[`GL`](@ref EarlyStopping.GL)`(alpha=2.0)` | † Stop after the "generalization loss (GL)" exceeds `alpha` | yes
-[`PQ`](@ref EarlyStopping.PQ)`(alpha=0.75, k=5)` | † Stop after "progress-modified GL" exceeds `alpha` | yes
-[`Patience`](@ref EarlyStopping.Patience)`(n=5)` | † Stop after `n` consecutive loss increases | yes
-[`Info`](@ref IterationControl.Info)`(f=identity)` | Log to `Info` the value of `f(mach)`, where `mach` is current machine | no
-[`Warn`](@ref IterationControl.Warn)`(predicate; f="")` | Log to `Warn` the value of `f` or `f(mach)`, if `predicate(mach)` holds | no
-[`Error`](@ref IterationControl.Error)`(predicate; f="")` | Log to `Error` the value of `f` or `f(mach)`, if `predicate(mach)` holds and then stop | yes
-[`Callback`](@ref IterationControl.Callback)`(f=mach->nothing)` | Call `f(mach)` | yes
-[`WithNumberDo`](@ref IterationControl.WithNumberDo)`(f=n->@info(n))` | Call `f(n + 1)` where `n` is number of previous calls | yes
-[`WithIterationsDo`](@ref MLJIteration.WithIterationsDo)`(f=i->@info("num iterations: $i"))` | Call `f(i)`, where `i` is total number of iterations | yes
-[`WithLossDo`](@ref IterationControl.WithLossDo)`(f=x->@info("loss: $x"))` | Call `f(loss)` where `loss` is the current loss | yes
-[`WithTrainingLossesDo`](@ref IterationControl.WithTrainingLossesDo)`(f=v->@info(v))` | Call `f(v)` where `v` is the current batch of training losses | yes
-[`Save`](@ref MLJSerialization.Save)`(filename="machine.jlso")` | * Save current machine to `machine1.jlso`, `machine2.jslo`, etc | yes
+control | description | can trigger a stop
+--------|-------------|--------------------
+[`Step`](@ref IterationControl.Step)`(n=1)` | Train model for `n` more iterations | no
+[`TimeLimit`](@ref EarlyStopping.TimeLimit)`(t=0.5)` | Stop after `t` hours | yes
+[`NumberLimit`](@ref EarlyStopping.NumberLimit)`(n=100)` | Stop after `n` applications of the control | yes
+[`NumberSinceBest`](@ref EarlyStopping.NumberSinceBest)`(n=6)` | Stop when best loss occurred `n` control applications ago | yes
+[`InvalidValue`](@ref IterationControl.InvalidValue)() | Stop when `NaN`, `Inf` or `-Inf` loss/training loss encountered | yes
+[`Threshold`](@ref EarlyStopping.Threshold)`(value=0.0)` | Stop when `loss < value` | yes
+[`GL`](@ref EarlyStopping.GL)`(alpha=2.0)` | † Stop after the "generalization loss (GL)" exceeds `alpha` | yes
+[`PQ`](@ref EarlyStopping.PQ)`(alpha=0.75, k=5)` | † Stop after "progress-modified GL" exceeds `alpha` | yes
+[`Patience`](@ref EarlyStopping.Patience)`(n=5)` | † Stop after `n` consecutive loss increases | yes
+[`Info`](@ref IterationControl.Info)`(f=identity)` | Log to `Info` the value of `f(mach)`, where `mach` is current machine | no
+[`Warn`](@ref IterationControl.Warn)`(predicate; f="")` | Log to `Warn` the value of `f` or `f(mach)`, if `predicate(mach)` holds | no
+[`Error`](@ref IterationControl.Error)`(predicate; f="")` | Log to `Error` the value of `f` or `f(mach)`, if `predicate(mach)` holds and then stop | yes
+[`Callback`](@ref IterationControl.Callback)`(f=mach->nothing)` | Call `f(mach)` | yes
+[`WithNumberDo`](@ref IterationControl.WithNumberDo)`(f=n->@info(n))` | Call `f(n + 1)` where `n` is the number of complete control cycles so far | yes
+[`WithIterationsDo`](@ref MLJIteration.WithIterationsDo)`(f=i->@info("num iterations: $i"))` | Call `f(i)`, where `i` is total number of iterations | yes
+[`WithLossDo`](@ref IterationControl.WithLossDo)`(f=x->@info("loss: $x"))` | Call `f(loss)` where `loss` is the current loss | yes
+[`WithTrainingLossesDo`](@ref IterationControl.WithTrainingLossesDo)`(f=v->@info(v))` | Call `f(v)` where `v` is the current batch of training losses | yes
+[`Save`](@ref MLJSerialization.Save)`(filename="machine.jlso")` | * Save current machine to `machine1.jlso`, `machine2.jslo`, etc | yes

 > Table 1. Atomic controls. Some advanced options omitted.

@@ -130,11 +130,12 @@ specified in the constructor: `Callback`, `WithNumberDo`,

 There are also three control wrappers to modify a control's behavior:

-wrapper | description
---------|-------------
-[`IterationControl.skip`](@ref)`(control, predicate=1)` | Apply `control` every `predicate` applications of the control wrapper (can also be a function; see doc-string)
-[`IterationControl.debug`](@ref)`(control)` | Apply `control` but also log its state to `Info` (irrespective of `verbosity` level)
-[`IterationControl.composite`](@ref)`(controls...)` | Apply each `control` in `controls` in sequence; used internally by IterationControl.jl
+wrapper | description
+--------|-------------
+[`IterationControl.skip`](@ref)`(control, predicate=1)` | Apply `control` every `predicate` applications of the control wrapper (can also be a function; see doc-string)
+[`IterationControl.louder`](@ref IterationControl.louder)`(control, by=1)` | Increase the verbosity level of `control` by the specified value (negative values lower verbosity)
+[`IterationControl.debug`](@ref)`(control)` | Apply `control` but also log its state to `Info` (irrespective of `verbosity` level)
+[`IterationControl.composite`](@ref)`(controls...)` | Apply each `control` in `controls` in sequence; used internally by IterationControl.jl

 > Table 2. Wrapped controls

@@ -245,7 +246,8 @@ one might use `IterateFromList([round(Int, 10^x) for x in range(1, 2,
 length=10)]`.

 In the code, `wrapper` is an object that wraps the training machine
-(see above).
+(see above). The variable `n` is a counter for control cycles (unused
+in this example).

 ```julia

@@ -256,14 +258,14 @@ struct IterateFromList
     IterateFromList(v) = new(unique(sort(v)))
 end

-function IterationControl.update!(control::IterateFromList, wrapper, verbosity)
+function IterationControl.update!(control::IterateFromList, wrapper, verbosity, n)
     Δi = control.list[1]
     verbosity > 1 && @info "Training $Δi more iterations. "
     MLJIteration.train!(wrapper, Δi) # trains the training machine
     return (index = 2, )
 end

-function IterationControl.update!(control::IterateFromList, wrapper, verbosity, state)
+function IterationControl.update!(control::IterateFromList, wrapper, verbosity, n, state)
     index = state.positioin_in_list
     Δi = control.list[i] - wrapper.n_iterations
     verbosity > 1 && @info "Training $Δi more iterations. "
@@ -339,12 +341,14 @@ only a single control, called `control`, then training proceeds as
 follows:

 ```julia
-state = update!(control, wrapper, verbosity)
+n = 1 # initialize control cycle counter
+state = update!(control, wrapper, verbosity, n)
 finished = done(control, state)

 # subsequent training events:
 while !finished
-    state = update!(control, wrapper, verbosity, state)
+    n += 1
+    state = update!(control, wrapper, verbosity, n, state)
     finished = done(control, state)
 end

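The signature change running through this file (an extra control-cycle counter `n` passed to `update!`) means custom controls now receive the counter explicitly instead of tracking it in their own state. A minimal sketch of a stateless custom control under the revised protocol; the `IterationLogger` name is made up for illustration:

```julia
import IterationControl

struct IterationLogger end  # hypothetical control; holds no options

# Handles both the first call (no prior state) and later calls (with state);
# the counter `n` is now supplied by the framework:
function IterationControl.update!(::IterationLogger, wrapper, verbosity, n, state=nothing)
    verbosity > 0 && @info "control cycle: $n"
    return nothing  # this control keeps no state of its own
end

# Never requests a stop:
IterationControl.done(::IterationLogger, state) = false
```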
@@ -386,14 +390,14 @@ end
 function IterationControl.update!(control::CycleLearningRate,
                                   wrapper,
                                   verbosity,
-                                  state = (n = 0, learning_rates=nothing))
-    n = state.n
+                                  n,
+                                  state = (learning_rates=nothing, ))
     rates = n == 0 ? one_cycle(control) : state.learning_rates
     index = mod(n, length(rates)) + 1
     r = rates[index]
     verbosity > 1 && @info "learning rate: $r"
     wrapper.model.iteration_control = r
-    return (n = n + 1, learning_rates = rates)
+    return (learning_rates = rates,)
 end
 ```

@@ -411,7 +415,7 @@ IterationControl.Step
 EarlyStopping.TimeLimit
 EarlyStopping.NumberLimit
 EarlyStopping.NumberSinceBest
-EarlyStopping.NotANumber
+EarlyStopping.InvalidValue
 EarlyStopping.Threshold
 EarlyStopping.GL
 EarlyStopping.PQ
@@ -431,6 +435,7 @@ MLJSerialization.Save

 ```@docs
 IterationControl.skip
+IterationControl.louder
 IterationControl.debug
 IterationControl.composite
 ```
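The newly documented `louder` wrapper composes with the existing ones inside a `controls` list. A hedged usage sketch (the particular controls and their arguments are illustrative, taken from Tables 1 and 2):

```julia
using MLJ
import IterationControl

controls = [Step(5),                                           # train 5 iterations per cycle
            IterationControl.skip(WithLossDo(), predicate=2),  # report loss every 2nd cycle
            IterationControl.louder(NumberSinceBest(4), by=1)] # raise this control's verbosity
```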
(two binary image files changed: 60.2 KB and 16.9 KB)

docs/src/img/tuning_plot.png
(binary image, 10.5 KB)

docs/src/index.md
Lines changed: 5 additions & 1 deletion

@@ -1,7 +1,11 @@
 ```@raw html
+<script async defer src="https://buttons.github.io/buttons.js"></script>
+
 <span style="color:darkslateblue;font-size:2.25em;font-style:italic;">
 A Machine Learning Framework for Julia
-</span>
+</span> &nbsp; &nbsp; &nbsp; &nbsp;
+<a class="github-button" href="https://github.com/alan-turing-institute/MLJ.jl" data-icon="octicon-star" data-size="large" data-show-count="true" aria-label="Star alan-turing-institute/MLJ.jl on GitHub">Star</a>
+
 <br>
 <br>
 <div style="font-size:1.25em;font-weight:bold;">

docs/src/tuning_models.md
Lines changed: 7 additions & 4 deletions

@@ -6,7 +6,7 @@ tuning strategies. Also available is the [tree
 Parzen](https://github.com/IQVIA-ML/TreeParzen.jl) strategy; for a
 complete list, see
 [here](https://github.com/alan-turing-institute/MLJTuning.jl#what-is-provided-here).
-
+
 MLJ tuning is implemented as an *iterative* procedure, which can
 accordingly be controlled using MLJ's [`IteratedModel`](@ref
 MLJIteration.IteratedModel) wrapper. After familiarizing one self with
@@ -129,7 +129,7 @@ info("KNNClassifier").prediction_type
 ```

 ```@example goof
-X, y = @load_iris
+X, y = @load_iris
 KNN = @load KNNClassifier verbosity=0
 knn = KNN()
 ```
@@ -267,7 +267,7 @@ mach = machine(self_tuning_forest, X, y);
 fit!(mach, verbosity=0);
 ```

-In this two-parameter case, a plot of the grid search results is also
+We can plot the grid search results, which are also
 available:

 ```julia
@@ -355,11 +355,14 @@ confused with the iteration parameter `n` in the construction of a
 corresponding `TunedModel` instance, which specifies the total number
 of models to be evaluated, independent of the tuning strategy.

+For this illustration we'll add a third, nominal, hyper-parameter:
+
 ```@example goof
+r3 = range(forest, :(atom.post_prune), values=[true, false]);
 self_tuning_forest = TunedModel(model=forest,
                                 tuning=latin,
                                 resampling=CV(nfolds=6),
-                                range=[r1, r2],
+                                range=[r1, r2, r3],
                                 measure=rms,
                                 n=25);
 mach = machine(self_tuning_forest, X, y);
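Nominal hyper-parameters, as in the `r3` range just added, are specified with `values=...` rather than with bounds. A minimal sketch of the two range styles; the field names `lambda` and `letter` are made up for illustration:

```julia
using MLJ

# numeric range: bounded, here with a log scale
r_num = range(Float64, :lambda, lower=1e-3, upper=1e3, scale=:log10)

# nominal range: an explicit list of allowed values
r_nom = range(Char, :letter, values=['a', 'b'])
```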
