Early stopping

RxInfer supports early stopping as an opt-in callback via StopEarlyIterationStrategy. For a general overview of the callbacks system, see the Callbacks section.

Constructors

  • StopEarlyIterationStrategy(rtol): sets atol = 0.0 and uses the given relative tolerance.
  • StopEarlyIterationStrategy(atol, rtol): sets both absolute and relative tolerances.

Both constructors use start_fe_value = Inf as the initial comparison value.

RxInfer.StopEarlyIterationStrategy (Type)
StopEarlyIterationStrategy

Early-stopping criterion based on consecutive Bethe free energy (FE) values.

Fields

  • atol::Float64: Absolute tolerance.
  • rtol::Float64: Relative tolerance.
  • start_fe_value::Float64: Initial FE reference used before the first iteration.
  • fe_values::Vector{Float64}: History of observed FE values (most recent is last).

Constructors

  • StopEarlyIterationStrategy(rtol): uses atol = 0.0, custom rtol.
  • StopEarlyIterationStrategy(atol, rtol): custom absolute and relative tolerances.

Both constructors use start_fe_value = Inf by default to avoid immediate stopping on the first iteration.

RxInfer.StopEarlyIterationStrategy (Method)
StopEarlyIterationStrategy(atol::Real, rtol::Real)

Create an early-stopping strategy with explicit absolute (atol) and relative (rtol) tolerances. Uses start_fe_value = Inf by default.

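The docstring describes the criterion only as a comparison of consecutive Bethe free energy values. The sketch below assumes that this comparison uses Julia's Base.isapprox with atol and rtol; the helpers fe_converged and first_stop_iteration are hypothetical and not part of RxInfer. It also illustrates why start_fe_value = Inf prevents a stop on the first iteration: isapprox never treats Inf as close to a finite value.

```julia
# Sketch of a consecutive-free-energy convergence check.
# ASSUMPTION: tolerances are applied via Base.isapprox; fe_converged and
# first_stop_iteration are hypothetical helpers, not RxInfer functions.

# true when two consecutive FE values agree within atol/rtol.
# isapprox returns false whenever exactly one argument is infinite,
# so a start value of Inf never counts as converged.
fe_converged(previous_fe::Real, current_fe::Real; atol::Real = 0.0, rtol::Real) =
    isapprox(previous_fe, current_fe; atol = atol, rtol = rtol)

# Return the first iteration at which the check would stop inference,
# or `nothing` if no two consecutive values are close enough.
function first_stop_iteration(fe_values; start_fe_value = Inf, atol = 0.0, rtol)
    previous = start_fe_value
    for (i, fe) in enumerate(fe_values)
        fe_converged(previous, fe; atol = atol, rtol = rtol) && return i
        previous = fe
    end
    return nothing
end

example_fe = [120.0, 101.3, 101.25, 101.2499]  # made-up FE values, one per iteration
println(first_stop_iteration(example_fe; atol = 1e-10, rtol = 1e-3))  # 3
```

With these made-up values the check fires at iteration 3, because 101.3 and 101.25 differ by less than 0.1% of the larger value, while the first iteration can never match the Inf starting value.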

Early stopping mechanism

The BeforeIterationEvent and AfterIterationEvent carry a mutable stop_iteration::Bool field (default false). Any callback can set event.stop_iteration = true to signal the inference engine to stop iterating. StopEarlyIterationStrategy uses this mechanism internally: when the free energy has converged, it sets event.stop_iteration = true.

For more about callbacks in static inference, see the static inference documentation.

Note that the iterations argument is still required: with early stopping, it specifies the maximum number of iterations.
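To illustrate the stop_iteration flag together with the role of iterations as an upper bound, here is a toy, self-contained loop that does not use RxInfer. ToyAfterIterationEvent, run_toy_inference and stop_after_three are hypothetical names; RxInfer's actual event types and engine internals may differ.

```julia
# Toy model of the stop_iteration protocol. ToyAfterIterationEvent,
# run_toy_inference and stop_after_three are hypothetical, not RxInfer API.

mutable struct ToyAfterIterationEvent
    iteration::Int
    stop_iteration::Bool  # starts as false; a callback may set it to true
end

# Runs at most `iterations` steps (the cap), calls `after_iteration` with a
# fresh event after each step, and breaks out as soon as the flag is set.
# Returns the number of iterations actually performed.
function run_toy_inference(; iterations::Int, after_iteration = nothing)
    performed = 0
    for i in 1:iterations
        performed = i  # a real engine would update its beliefs here
        if after_iteration !== nothing
            event = ToyAfterIterationEvent(i, false)
            after_iteration(event)
            event.stop_iteration && break
        end
    end
    return performed
end

# A custom callback that requests a stop once the third iteration has finished.
function stop_after_three(event::ToyAfterIterationEvent)
    if event.iteration >= 3
        event.stop_iteration = true
    end
    return nothing
end

println(run_toy_inference(iterations = 50, after_iteration = stop_after_three))  # 3
println(run_toy_inference(iterations = 50))                                      # 50
```

With the callback the loop stops after 3 of the allowed 50 iterations; without it, the loop runs until the iterations cap.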

Example

using RxInfer

@model function iid_normal(y)
    m ~ Normal(mean = 0.0, variance = 1.0)
    tau ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = m, precision = tau)
end

data = (y = randn(100),)
max_iterations = 50
initialization = @initialization begin
    q(m) = NormalMeanVariance(0.0, 1.0)
    q(tau) = GammaShapeRate(1.0, 1.0)
end

result = infer(
    model = iid_normal(),
    data = data,
    constraints = MeanField(),
    initialization = initialization,
    free_energy = true,
    iterations = max_iterations,
    callbacks = (
        after_iteration = StopEarlyIterationStrategy(1e-10, 1e-3),
    )
)

length(result.free_energy)  # 3 in the documented run; the count may differ because the data are random

The number of recorded free energy values, one per performed iteration, is smaller than max_iterations, which shows that inference stopped early.