Early stopping
RxInfer supports early stopping as an opt-in callback via StopEarlyIterationStrategy. For a general overview of the callbacks system, see the Callbacks section.
Constructors
- StopEarlyIterationStrategy(rtol): sets atol = 0.0 and uses the given relative tolerance.
- StopEarlyIterationStrategy(atol, rtol): sets both absolute and relative tolerances.
- Both constructors use start_fe_value = Inf for the initial comparison value.
RxInfer.StopEarlyIterationStrategy — Type
StopEarlyIterationStrategy
Early-stopping criterion based on consecutive Bethe free energy (FE) values.
Fields
- atol::Float64: Absolute tolerance.
- rtol::Float64: Relative tolerance.
- start_fe_value::Float64: Initial FE reference used before the first iteration.
- fe_values::Vector{Float64}: History of observed FE values (most recent is last).
Constructors
- StopEarlyIterationStrategy(rtol): uses atol = 0.0 and a custom rtol.
- StopEarlyIterationStrategy(atol, rtol): custom absolute and relative tolerances.
Both constructors use start_fe_value = Inf by default to avoid immediate stopping on the first iteration.
RxInfer.StopEarlyIterationStrategy — Method
StopEarlyIterationStrategy(rtol::Real)
Create an early-stopping strategy with atol = 0.0 and the given rtol. Uses start_fe_value = Inf by default.
RxInfer.StopEarlyIterationStrategy — Method
StopEarlyIterationStrategy(atol::Real, rtol::Real)
Create an early-stopping strategy with explicit absolute (atol) and relative (rtol) tolerances. Uses start_fe_value = Inf by default.
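The convergence check can be sketched in plain Julia to build intuition. This is not RxInfer's source: the type ToyStopEarly and the function record_and_check! are hypothetical, and the sketch assumes that "converged" means the newest free energy value is isapprox the previous one under the stored atol and rtol. It mirrors the documented fields and constructors and shows why starting from Inf prevents a stop on the first iteration.

```julia
# Hypothetical stand-in for StopEarlyIterationStrategy (NOT RxInfer's implementation).
mutable struct ToyStopEarly
    atol::Float64
    rtol::Float64
    start_fe_value::Float64
    fe_values::Vector{Float64}
end

# Mirror the documented constructors: atol defaults to 0.0, start_fe_value to Inf.
ToyStopEarly(rtol::Real) = ToyStopEarly(0.0, rtol)
ToyStopEarly(atol::Real, rtol::Real) = ToyStopEarly(Float64(atol), Float64(rtol), Inf, Float64[])

# Record a new free energy value and report whether it is close to the previous one.
# Assumption: "close" means isapprox with the stored absolute and relative tolerances.
function record_and_check!(s::ToyStopEarly, fe::Real)
    previous = isempty(s.fe_values) ? s.start_fe_value : last(s.fe_values)
    push!(s.fe_values, fe)
    # isapprox(Inf, x) is false for every finite x, so the first call never signals a stop.
    return isapprox(previous, fe; atol = s.atol, rtol = s.rtol)
end

# Usage: a synthetic free energy trace that converges geometrically towards 100.0.
s = ToyStopEarly(1e-10, 1e-3)
trace = [100.0 + 10.0 / 2^k for k in 0:20]
# findfirst evaluates the predicate in order and stops at the first `true`.
stopped_at = findfirst(fe -> record_and_check!(s, fe), trace)  # returns 8
```

With rtol = 1e-3 the check fires once consecutive values differ by less than roughly 0.1% of their magnitude; here that happens at the 8th value, well before the 21 available.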
Early stopping mechanism
The BeforeIterationEvent and AfterIterationEvent carry a mutable stop_iteration::Bool field (default false). Any callback can set event.stop_iteration = true to signal the inference engine to stop iterating. The StopEarlyIterationStrategy uses this mechanism internally: when the free energy has converged, it sets event.stop_iteration = true.
For more about callbacks in static inference, see the Callbacks section.
Note that iterations must still be specified: with early stopping, it sets the maximum number of iterations.
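The control flow described above can be imitated with a toy loop. Everything here is hypothetical: MockAfterIterationEvent and toy_run only mimic the documented behaviour (a mutable stop_iteration flag that any callback may set, and iterations acting as an upper bound). They are not RxInfer types or its inference engine, and the real callback signature may differ.

```julia
# Hypothetical mock of an after-iteration event; NOT RxInfer's AfterIterationEvent type.
mutable struct MockAfterIterationEvent
    iteration::Int
    stop_iteration::Bool
end

# Toy driver: runs at most `iterations` steps and calls `after_iteration` after each one.
# It stops early as soon as the callback sets `event.stop_iteration = true`.
function toy_run(step, iterations::Int; after_iteration = event -> nothing)
    completed = 0
    for i in 1:iterations
        step(i)
        completed = i
        event = MockAfterIterationEvent(i, false)
        after_iteration(event)
        event.stop_iteration && break
    end
    return completed
end

# A custom callback that requests a stop once the third iteration has finished.
stop_after_three = event -> (event.stop_iteration = event.iteration >= 3)

n_early = toy_run(i -> nothing, 50; after_iteration = stop_after_three)  # returns 3
n_full = toy_run(i -> nothing, 50)                                        # returns 50
```

Without a callback that raises the flag, the loop always runs the full iterations count, which is why iterations still has to be given when early stopping is enabled.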
Example
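using RxInfer

# IID Gaussian observations with unknown mean m and precision tau
@model function iid_normal(y)
    m ~ Normal(mean = 0.0, variance = 1.0)
    tau ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = m, precision = tau)
end

data = (y = randn(100),)

# Upper bound on the number of iterations; early stopping may end sooner
max_iterations = 50

# Mean-field inference needs initial marginals for m and tau
initialization = @initialization begin
    q(m) = NormalMeanVariance(0.0, 1.0)
    q(tau) = GammaShapeRate(1.0, 1.0)
end

result = infer(
    model = iid_normal(),
    data = data,
    constraints = MeanField(),
    initialization = initialization,
    free_energy = true,  # track the free energy that the stopping criterion compares
    iterations = max_iterations,
    callbacks = (
        # atol = 1e-10, rtol = 1e-3
        after_iteration = StopEarlyIterationStrategy(1e-10, 1e-3),
    )
)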
length(result.free_energy)
3

As you can see, the total number of free energy evaluations is less than max_iterations (the exact count depends on the randomly generated data).