Early stopping
RxInfer supports early stopping as an opt-in callback via StopEarlyIterationStrategy.
Constructors
- StopEarlyIterationStrategy(rtol): sets atol = 0.0, uses the given relative tolerance.
- StopEarlyIterationStrategy(atol, rtol): sets both absolute and relative tolerances.
- Both constructors use start_fe_value = Inf for the initial comparison value.
RxInfer.StopEarlyIterationStrategy — Type
StopEarlyIterationStrategy
Early-stopping criterion based on consecutive Bethe free energy (FE) values.
Fields
- atol::Float64: Absolute tolerance.
- rtol::Float64: Relative tolerance.
- start_fe_value::Float64: Initial FE reference used before the first iteration.
- fe_values::Vector{Float64}: History of observed FE values (most recent is last).
Constructors
- StopEarlyIterationStrategy(rtol): uses atol = 0.0, custom rtol.
- StopEarlyIterationStrategy(atol, rtol): custom absolute and relative tolerances.
Both constructors use start_fe_value = Inf by default to avoid immediate stopping on the first iteration.
RxInfer.StopEarlyIterationStrategy — Method
StopEarlyIterationStrategy(rtol::Real)
Create an early-stopping strategy with atol = 0.0 and the given rtol. Uses start_fe_value = Inf by default.
RxInfer.StopEarlyIterationStrategy — Method
StopEarlyIterationStrategy(atol::Real, rtol::Real)
Create an early-stopping strategy with explicit absolute (atol) and relative (rtol) tolerances. Uses start_fe_value = Inf by default.
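The interaction between atol, rtol and start_fe_value = Inf can be illustrated with a minimal sketch. It assumes the comparison between consecutive FE values follows the semantics of Julia's Base.isapprox, which is an assumption made for illustration; the actual RxInfer implementation may differ. The helper first_converged_index and the fe_trace values are hypothetical and not part of RxInfer.

```julia
# Hypothetical helper (not part of RxInfer): returns the index of the first
# FE value that is within tolerance of the previous one, or `nothing`.
# Assumption: the tolerance check behaves like Base.isapprox, i.e.
# |current - previous| <= max(atol, rtol * max(|current|, |previous|)).
function first_converged_index(fe_values; atol = 0.0, rtol = 1e-3, start_fe_value = Inf)
    previous = start_fe_value
    for (i, current) in enumerate(fe_values)
        # isapprox returns false when exactly one argument is infinite,
        # so the Inf start value can never trigger a stop on iteration 1.
        if isapprox(current, previous; atol = atol, rtol = rtol)
            return i
        end
        previous = current
    end
    return nothing
end

# Made-up free energy trace that flattens out after the second iteration.
fe_trace = [250.0, 180.0, 179.95, 179.94]

first_converged_index(fe_trace; atol = 1e-10, rtol = 1e-3)  # returns 3
```

In this sketch the third value differs from the second by about 0.05, which is below rtol * 180.0 = 0.18, so the check succeeds at index 3. With a much tighter rtol the helper returns nothing, which would correspond to running all iterations.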
Example
To use RxInfer.StopEarlyIterationStrategy, pass it as the after_iteration field of the callbacks argument of the infer function. See the documentation on callbacks for static inference for more details.
Note that iterations must still be specified; with early stopping it sets the maximum number of iterations.
using RxInfer

@model function iid_normal(y)
    m ~ Normal(mean = 0.0, variance = 1.0)
    tau ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = m, precision = tau)
end

data = (y = randn(100),)
max_iterations = 50

initialization = @initialization begin
    q(m) = NormalMeanVariance(0.0, 1.0)
    q(tau) = GammaShapeRate(1.0, 1.0)
end

result = infer(
    model = iid_normal(),
    data = data,
    constraints = MeanField(),
    initialization = initialization,
    free_energy = true,
    iterations = max_iterations,
    callbacks = (
        after_iteration = StopEarlyIterationStrategy(1e-10, 1e-3),
    )
)
length(result.free_energy)
3
As you can see, the total number of free energy evaluations is less than max_iterations.