Early stopping

RxInfer supports early stopping as an opt-in callback via StopEarlyIterationStrategy.

Constructors

  • StopEarlyIterationStrategy(rtol): sets atol = 0.0, uses the given relative tolerance.
  • StopEarlyIterationStrategy(atol, rtol): sets both absolute and relative tolerances.
  • Both constructors use start_fe_value = Inf for the initial comparison value.

RxInfer.StopEarlyIterationStrategy (Type)
StopEarlyIterationStrategy

Early-stopping criterion based on consecutive Bethe free energy (FE) values.

Fields

  • atol::Float64: Absolute tolerance.
  • rtol::Float64: Relative tolerance.
  • start_fe_value::Float64: Initial FE reference used before the first iteration.
  • fe_values::Vector{Float64}: History of observed FE values (most recent is last).

Constructors

  • StopEarlyIterationStrategy(rtol): uses atol = 0.0, custom rtol.
  • StopEarlyIterationStrategy(atol, rtol): custom absolute and relative tolerances.

Both constructors use start_fe_value = Inf by default to avoid immediate stopping on the first iteration.

RxInfer.StopEarlyIterationStrategy (Method)
StopEarlyIterationStrategy(atol::Real, rtol::Real)

Create an early-stopping strategy with explicit absolute (atol) and relative (rtol) tolerances. Uses start_fe_value = Inf by default.
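
The role of the two tolerances and of start_fe_value = Inf can be illustrated with a small standalone sketch. It assumes that consecutive FE values are compared in the manner of Julia's isapprox with atol and rtol; the actual check inside RxInfer may differ. The helper names converged and iterations_until_stop, as well as the FE trace, are hypothetical and made up for illustration; they are not part of RxInfer.

```julia
# Hypothetical helper (not part of RxInfer): assumes the criterion behaves like
# Base.isapprox applied to two consecutive free energy values.
converged(previous, current; atol = 0.0, rtol = 1e-3) =
    isapprox(previous, current; atol = atol, rtol = rtol)

# Hypothetical helper: walks through a sequence of FE values and returns the
# index at which the criterion above is first satisfied, or the length of the
# sequence if it is never satisfied.
function iterations_until_stop(fe_values; atol = 0.0, rtol = 1e-3, start_fe_value = Inf)
    previous = start_fe_value
    for (i, current) in enumerate(fe_values)
        converged(previous, current; atol = atol, rtol = rtol) && return i
        previous = current
    end
    return length(fe_values)
end

# Made-up FE trace, for illustration only.
fe_trace = [150.0, 142.0, 141.95, 141.949, 141.9489]

stopped_at = iterations_until_stop(fe_trace; atol = 1e-10, rtol = 1e-3)
println(stopped_at)
```

In this sketch the first comparison is against Inf, which isapprox never treats as close to a finite value, so the check cannot succeed on the first iteration. The second value differs from the first by far more than the relative tolerance allows, and the third value is the first one that falls within tolerance of its predecessor.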


Example

To use RxInfer.StopEarlyIterationStrategy, pass it as the after_iteration field of the callbacks argument of the infer function. See the documentation on callbacks for static inference for more details.

Note that iterations must still be specified. With early stopping enabled, it sets the maximum number of iterations.

using RxInfer

@model function iid_normal(y)
    m ~ Normal(mean = 0.0, variance = 1.0)
    tau ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = m, precision = tau)
end

data = (y = randn(100),)
max_iterations = 50
initialization = @initialization begin
    q(m) = NormalMeanVariance(0.0, 1.0)
    q(tau) = GammaShapeRate(1.0, 1.0)
end

result = infer(
    model = iid_normal(),
    data = data,
    constraints = MeanField(),
    initialization = initialization,
    free_energy = true,
    iterations = max_iterations,
    callbacks = (
        after_iteration = StopEarlyIterationStrategy(1e-10, 1e-3),
    )
)

length(result.free_energy)
3

As you can see, the number of free energy evaluations (3) is smaller than max_iterations (50), so inference stopped early.