A small package for applying early stopping criteria to loss-generating iterative algorithms, with a view to training and optimizing machine learning models.
The basis of IterationControl.jl, a package for externally controlling iterative algorithms.
Includes the stopping criteria surveyed in Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.
using Pkg
Pkg.add("EarlyStopping")The EarlyStopper objects defined in this package consume a sequence
of numbers called losses generated by some external algorithm -
generally the training loss or out-of-sample loss of some iterative
statistical model - and decide when those losses have dropped
sufficiently to warrant terminating the algorithm. A number of
commonly applied stopping criteria, listed under
Criteria below, are provided out-of-the-box.
Here's an example of using an EarlyStopper object to check losses against
two of these criteria (either of which can trigger the stop):
using EarlyStopping
stopper = EarlyStopper(Patience(2), InvalidValue()) # multiple criteria
done!(stopper, 0.123) # false
done!(stopper, 0.234) # false
done!(stopper, 0.345) # true
julia> message(stopper)
"Early stop triggered by Patience(2) stopping criterion. "One may force an EarlyStopper to report its evolving state:
losses = [10.0, 11.0, 10.0, 11.0, 12.0, 10.0];
stopper = EarlyStopper(Patience(2), verbosity=1);
for loss in losses
done!(stopper, loss) && break
end

[ Info: loss: 10.0 state: (loss = 10.0, n_increases = 0)
[ Info: loss: 11.0 state: (loss = 11.0, n_increases = 1)
[ Info: loss: 10.0 state: (loss = 10.0, n_increases = 0)
[ Info: loss: 11.0 state: (loss = 11.0, n_increases = 1)
[ Info: loss: 12.0 state: (loss = 12.0, n_increases = 2)
The "object-oriented" interface demonstrated here is not code-optimized but will suffice for the majority of use-cases. For performant code, use the functional interface described under Implementing new criteria below.
To list all stopping criteria, run subtypes(StoppingCriterion). Each
subtype T has a detailed doc-string, queried with ?T at the
REPL. Here is a short summary:
| criterion | description | notation in Prechelt |
|---|---|---|
| Never() | Never stop | |
| InvalidValue() | Stop when NaN, Inf or -Inf encountered | |
| TimeLimit(t=0.5) | Stop after t hours | |
| NumberLimit(n=100) | Stop after n loss updates (excl. "training losses") | |
| NumberSinceBest(n=6) | Stop when the lowest loss so far occurred n loss updates ago (excl. "training losses") | |
| Threshold(value=0.0) | Stop when loss < value | |
| GL(alpha=2.0) | Stop after "Generalization Loss" exceeds alpha | GL_α |
| PQ(alpha=0.75, k=5) | Stop after "Progress-modified GL" exceeds alpha | PQ_α |
| Patience(n=5) | Stop after n consecutive loss increases | UP_s |
| Disjunction(c...) | Stop when any of the criteria c apply | |
| Warmup(c; n=1) | Wait for n loss updates before checking criteria c | |
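For example, here is a sketch combining two criteria from this table, one of them wrapped in Warmup (criteria passed to an EarlyStopper trigger a stop as soon as any one of them applies):

using EarlyStopping

# wait for 3 loss updates before checking Patience(2); stop on two
# consecutive increases in loss, or on any NaN or infinite loss:
stopper = EarlyStopper(Warmup(Patience(2), n=3), InvalidValue())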
For criteria tracking both an "out-of-sample" loss and a "training"
loss (eg, stopping criterion of type PQ), specify training=true if
the update is for training, as in
done!(stopper, 0.123, training=true)

In these cases, the out-of-sample update must always come after the corresponding training update. Multiple training updates may precede the out-of-sample update, as in the following example:
criterion = PQ(alpha=2.0, k=2)
needs_training_losses(criterion) # true
stopper = EarlyStopper(criterion)
done!(stopper, 9.5, training=true) # false
done!(stopper, 9.3, training=true) # false
done!(stopper, 10.0) # false
done!(stopper, 9.3, training=true) # false
done!(stopper, 9.1, training=true) # false
done!(stopper, 8.9, training=true) # false
done!(stopper, 8.0) # false
done!(stopper, 8.3, training=true) # false
done!(stopper, 8.4, training=true) # false
done!(stopper, 9.0) # true

Important. If there is no distinction between in-sample and out-of-sample
losses, then any criterion can be applied, and in that case training=true
is never specified (regardless of the actual interpretation of the
losses being tracked).
To determine the stopping time for an iterator losses, use
stopping_time(criterion, losses). This is useful for debugging new
criteria (see below). If the iterator terminates without a stop, 0
is returned.
julia> stopping_time(InvalidValue(), [10.0, 3.0, Inf, 4.0])
3
julia> stopping_time(Patience(3), [10.0, 3.0, 4.0, 5.0], verbosity=1)
[ Info: loss updates: 1
[ Info: state: (loss = 10.0, n_increases = 0)
[ Info: loss updates: 2
[ Info: state: (loss = 3.0, n_increases = 0)
[ Info: loss updates: 3
[ Info: state: (loss = 4.0, n_increases = 1)
[ Info: loss updates: 4
[ Info: state: (loss = 5.0, n_increases = 2)
0

If the losses include both training and out-of-sample losses as
described above, pass an extra Bool vector marking the training
losses with true, as in
stopping_time(PQ(),
[0.123, 0.321, 0.52, 0.55, 0.56, 0.58],
[true, true, false, true, true, false])

To implement a new stopping criterion, one must:
- Define a new struct for the criterion, which must subtype StoppingCriterion.
- Overload the methods update and done for the new type.
struct NewCriteria <: StoppingCriterion
# Put relevant fields here
end
# Provide a default constructor accepting keyword arguments
NewCriteria(; kwargs...) = ...
# Return the initial state of the NewCriteria after
# receiving an out-of-sample loss
update(c::NewCriteria, loss, ::Nothing) = ...
# Return an updated state for NewCriteria given a `loss`
# and the current `state`
update(c::NewCriteria, loss, state) = ...
# Return true if NewCriteria should stop given `state`.
# Always return false if `state === nothing`
done(c::NewCriteria, state) = state === nothing ? false : ...

Optionally, one may define the following:
- Customize the final report by overloading message.
- Handle training losses by overloading update_training and the trait needs_training_losses.
# Final message when NewCriteria triggers a stop
message(c::NewCriteria, state) = ...
# Methods for initializing/updating the state given a training loss
update_training(c::NewCriteria, loss, ::Nothing) = ...
update_training(c::NewCriteria, loss, state) = ...

Wrappers. If your criterion wraps another criterion (as Warmup
does), then the criterion being wrapped must be stored in a field of
the new type.
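For instance, a minimal sketch of such a wrapper (the type name and extra field are hypothetical, for illustration only):

using EarlyStopping

# hypothetical wrapper that delays checking the wrapped criterion:
struct Delayed{C<:StoppingCriterion} <: StoppingCriterion
    criterion::C   # the criterion being wrapped, stored in a field
    n::Int         # hypothetical extra parameter
end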
We now demonstrate implementing a new criterion with a simplified
version of the code for Patience:
using EarlyStopping
struct Patience <: StoppingCriterion
n::Int
end
Patience(; n=5) = Patience(n)

All information to be "remembered" must be passed around in an object
called state below, which is the return value of update (and
update_training). The update function has two methods:
- Initialization: update(c::NewCriteria, loss, ::Nothing)
- Subsequent loss updates: update(c::NewCriteria, loss, state)

Here state is the return value of the previous call to update or update_training.
Note that state === nothing indicates an uninitialized criterion.
import EarlyStopping: update, done
function update(criterion::Patience, loss, ::Nothing)
return (loss=loss, n_increases=0) # state
end
function update(criterion::Patience, loss, state)
old_loss, n = state
if loss > old_loss
n += 1
else
n = 0
end
return (loss=loss, n_increases=n) # state
end

The done method returns true or false depending on the state, but
always returns false if state === nothing.
done(criterion::Patience, state) =
state === nothing ? false : state.n_increases == criterion.n
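As a quick sanity check of this simplified implementation, assuming the definitions above have been evaluated (expected results shown in comments):

stopping_time(Patience(2), [10.0, 11.0, 12.0]) # 3
stopping_time(Patience(2), [10.0, 11.0, 10.0]) # 0 (no stop triggered)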
The final message of an EarlyStopper is generated using a message
method for StoppingCriterion. Here is the fallback (which does not
use state):
EarlyStopping.message(criterion::StoppingCriterion, state) =
    "Early stop triggered by $criterion stopping criterion. "

The optional update_training methods (two for each criterion) have
the same signature as the update methods above. Refer to the PQ
code for an example.
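For a self-contained illustration, here is a sketch of a hypothetical criterion, TrainingGap (invented for this example, not part of the package), that stops when the out-of-sample loss exceeds the most recent training loss by more than delta:

using EarlyStopping
import EarlyStopping: update, update_training, done

struct TrainingGap <: StoppingCriterion
    delta::Float64
end
TrainingGap(; delta=1.0) = TrainingGap(delta)

# training updates remember the most recent training loss:
update_training(c::TrainingGap, loss, ::Nothing) = (training=loss, gap=0.0)
update_training(c::TrainingGap, loss, state) = (training=loss, gap=state.gap)

# out-of-sample updates record the gap relative to the last training loss:
update(c::TrainingGap, loss, ::Nothing) = (training=loss, gap=0.0)
update(c::TrainingGap, loss, state) = (training=state.training, gap=loss - state.training)

done(c::TrainingGap, state) = state === nothing ? false : state.gap > c.delta

# declare that training losses are required (this trait is explained next):
EarlyStopping.needs_training_losses(::Type{<:TrainingGap}) = true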
If a stopping criterion requires one or more update_training calls
per update call to work, you should overload the trait
needs_training_losses for that type, as in this example from
the source code:
EarlyStopping.needs_training_losses(::Type{<:PQ}) = true

The following are provided to facilitate testing of new criteria:
- stopping_time: returns the stopping time for an iterator losses using criterion.
- @test_criteria NewCriteria(): runs a suite of unit tests against the provided StoppingCriterion. This macro is part of the test suite only and is not part of the API.