Callbacks

The infer function and the underlying reactive message passing engine each have their own lifecycle, consisting of multiple steps. By supplying callbacks, users can inject custom logic at specific moments during the inference procedure, for example for debugging, performance analysis, or early stopping.

Event-based callback system

All callbacks in RxInfer use an event-based dispatch system built on ReactiveMP.Event{E}, where E is a Symbol naming the event (for example, :after_iteration). Each callback event is a concrete struct that subtypes ReactiveMP.Event{E} and carries all relevant data as named fields. This makes callbacks self-documenting and extensible.

For example, an AfterIterationEvent has fields model and iteration:

# NamedTuple callback — receives the event struct
callbacks = (
    after_iteration = (event) -> println("Iteration ", event.iteration, " done"),
)
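The following toy sketch makes the dispatch idea concrete. It is self-contained Julia and does not use RxInfer. It assumes a simplified design with four parts: an abstract ToyEvent{E} tagged by a Symbol, one concrete event struct with named fields, a NamedTuple handler that looks up callbacks by event name, and a custom handler that specializes through multiple dispatch. All names (ToyEvent, ToyAfterIterationEvent, event_name, toy_handle_event, ToyCounter) are hypothetical. They only mimic the behaviour described in this section and are not ReactiveMP's actual internals.

```julia
# Toy model of event-based dispatch. All names are hypothetical and
# only mimic the idea; they are NOT ReactiveMP's internals.
abstract type ToyEvent{E} end

# A concrete event: tagged with a Symbol and carrying named fields
struct ToyAfterIterationEvent{M} <: ToyEvent{:after_iteration}
    model::M
    iteration::Int
end

# Recover the event name from the type parameter
event_name(::ToyEvent{E}) where {E} = E

# NamedTuple handler: call the function stored under the event's name, if any
function toy_handle_event(callbacks::NamedTuple, event::ToyEvent)
    name = event_name(event)
    haskey(callbacks, name) && callbacks[name](event)
    return nothing
end

# Custom handler: ignore everything by default, specialize via dispatch
struct ToyCounter
    count::Base.RefValue{Int}
end
toy_handle_event(::ToyCounter, ::ToyEvent) = nothing
toy_handle_event(c::ToyCounter, ::ToyEvent{:after_iteration}) = (c.count[] += 1; nothing)

# Usage: fire three fake iteration events at both kinds of handler
seen = Int[]
callbacks = (after_iteration = (event) -> push!(seen, event.iteration),)
counter = ToyCounter(Ref(0))
for i in 1:3
    event = ToyAfterIterationEvent(nothing, i)
    toy_handle_event(callbacks, event)
    toy_handle_event(counter, event)
end
println(seen, " ", counter.count[])   # prints [1, 2, 3] 3
```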

Callback types

The callbacks keyword argument of the infer function accepts three types of callback handlers:

NamedTuple

The simplest way to define callbacks is via a NamedTuple, where keys correspond to event names and values are functions that receive a single event object:

using RxInfer
using RxInfer.ReactiveMP

@model function coin_model(y)
    θ ~ Beta(1, 1)
    y .~ Bernoulli(θ)
end


result = infer(
    model = coin_model(),
    data  = (y = [1, 0, 1, 1, 0],),
    callbacks = (
        before_inference   = (event) -> begin
            # nameof drops the long type parameters and prints only the type name
            println("Starting inference on model: ", nameof(typeof(event.model)))
        end,
        after_inference    = (event) -> begin
            println("Inference completed")
        end,
    )
)
Starting inference on model: ProbabilisticModel
Inference completed
Warning

When defining a NamedTuple with a single entry, make sure to include a trailing comma. In Julia, (key = value) is parsed as a variable assignment, not a NamedTuple. Use (key = value,) (with trailing comma) instead. ReactiveMP will raise a helpful error if this mistake is detected.
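You can check the difference in plain Julia. The sketch below uses a hypothetical on_done function and compares three spellings. The last one, (; key = value) with a leading semicolon, is standard Julia syntax and an alternative to the trailing comma.

```julia
on_done = (event) -> println("done")

# No trailing comma: this is an assignment. It binds a variable named
# `after_inference` and evaluates to the function itself, not a NamedTuple
a = (after_inference = on_done)

# Trailing comma, or leading semicolon: a one-entry NamedTuple
b = (after_inference = on_done,)
c = (; after_inference = on_done)

println(a isa NamedTuple)   # false
println(b isa NamedTuple)   # true
println(b == c)             # true
```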

Dict

A Dict with Symbol keys (any Dict{Symbol}) works the same way as a NamedTuple. Unlike a NamedTuple, it can be built up dynamically, for example by adding entries conditionally at runtime:

my_callbacks = Dict(
    :before_inference => (event) -> println("Starting inference"),
    :after_inference  => (event) -> println("Inference completed"),
)

result = infer(
    model = coin_model(),
    data  = (y = [1, 0, 1, 1, 0],),
    callbacks = my_callbacks,
)
Starting inference
Inference completed
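Because entries can be added one at a time, a Dict makes it easy to register callbacks conditionally. The sketch below assumes a hypothetical settings NamedTuple that decides which callbacks to include. The resulting Dict{Symbol, Function} is a Dict{Symbol}, so it could be passed as callbacks = callbacks to infer, just like my_callbacks above. To keep the sketch runnable without a model, it invokes one callback by hand and passes a NamedTuple in place of the event. This only works because that callback reads nothing except event.iteration.

```julia
# Hypothetical settings that decide which callbacks to register at runtime
settings = (verbose = true, track_iterations = true)

# Dict{Symbol, Function} is a Dict{Symbol}, so it is accepted as `callbacks`
callbacks = Dict{Symbol, Function}()

if settings.verbose
    callbacks[:before_inference] = (event) -> println("Starting inference")
    callbacks[:after_inference]  = (event) -> println("Inference completed")
end

iterations_seen = Int[]
if settings.track_iterations
    callbacks[:after_iteration] = (event) -> push!(iterations_seen, event.iteration)
end

# Quick check without running inference: a NamedTuple stands in for the event
callbacks[:after_iteration]((iteration = 1,))
println(sort(collect(keys(callbacks))))
println(iterations_seen)   # [1]
```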

Custom callback structures

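For more advanced use cases, you can pass any custom structure as a callback handler. The structure must implement ReactiveMP.handle_event methods for the event types it wants to handle. It is common to also define a catch-all method for ReactiveMP.Event that ignores every other event, as in this example: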

struct MyCallbackHandler
    log::Vector{String}
end

# Catch-all: ignore events you don't care about
ReactiveMP.handle_event(::MyCallbackHandler, ::ReactiveMP.Event) = nothing

# Handle specific events by dispatching on the concrete event type
function ReactiveMP.handle_event(handler::MyCallbackHandler, event::BeforeInferenceEvent)
    push!(handler.log, "inference started")
end

function ReactiveMP.handle_event(handler::MyCallbackHandler, event::AfterInferenceEvent)
    push!(handler.log, "inference completed")
end

handler = MyCallbackHandler(String[])

result = infer(
    model = coin_model(),
    data  = (y = [1, 0, 1, 1, 0],),
    callbacks = handler,
)

println(handler.log)
["inference started", "inference completed"]

Dispatching by event name with Event{:name}

Since every event struct is a subtype of ReactiveMP.Event{E} where E is a Symbol, you can also dispatch on ReactiveMP.Event{:event_name} instead of the concrete type name. This is equivalent — you still have access to all the same fields — and can be more convenient when the concrete type is not exported. For example, ReactiveMP-level events like BeforeMessageRuleCallEvent are not exported by default, but you can always dispatch on ReactiveMP.Event{:before_message_rule_call}:

struct MyEventNameHandler
    log::Vector{String}
end

# Catch-all
ReactiveMP.handle_event(::MyEventNameHandler, ::ReactiveMP.Event) = nothing

# Dispatch using Event{:name} — no need to know the concrete struct name
function ReactiveMP.handle_event(handler::MyEventNameHandler, event::ReactiveMP.Event{:before_inference})
    push!(handler.log, "inference started (via Event name)")
end

function ReactiveMP.handle_event(handler::MyEventNameHandler, event::ReactiveMP.Event{:after_inference})
    push!(handler.log, "inference completed (via Event name)")
end

handler_by_name = MyEventNameHandler(String[])

result = infer(
    model = coin_model(),
    data  = (y = [1, 0, 1, 1, 0],),
    callbacks = handler_by_name,
)

println(handler_by_name.log)
["inference started (via Event name)", "inference completed (via Event name)"]

Both approaches are fully interchangeable — use whichever is more convenient for your use case.
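If a handler ever defines both styles for the same event, ordinary Julia dispatch decides which one runs. The method for the concrete type is more specific than the one for Event{:name}, so it wins. The toy types below illustrate this rule without RxInfer. SketchEvent and its subtypes are hypothetical stand-ins for ReactiveMP.Event{E} and its concrete events.

```julia
# Hypothetical stand-ins: SketchEvent{E} plays the role of ReactiveMP.Event{E}
abstract type SketchEvent{E} end
struct SketchBeforeInference <: SketchEvent{:before_inference} end
struct SketchAfterInference  <: SketchEvent{:after_inference} end

# Catch-all, then name-based methods, then one concrete-type method
describe(::SketchEvent) = "catch-all"
describe(::SketchEvent{:before_inference}) = "by event name"
describe(::SketchEvent{:after_inference}) = "by event name"
describe(::SketchBeforeInference) = "by concrete type"

println(describe(SketchBeforeInference()))   # by concrete type
println(describe(SketchAfterInference()))    # by event name
```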

Custom callback structures are useful when you need to:

  • Maintain state across events (e.g., collecting timing information)
  • Implement complex logic that spans multiple events
  • Store information in the model's metadata dictionary for later access

RxInfer provides built-in callback structures such as RxInferBenchmarkCallbacks and StopEarlyIterationStrategy as examples of this pattern.
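The following sketch illustrates the first use case listed above, collecting timing information. It measures the wall-clock duration of each variational iteration. It assumes that the before_iteration and after_iteration events (see the Note under Available events) can be matched with ReactiveMP.Event{:before_iteration} and ReactiveMP.Event{:after_iteration}, following the Event{:name} convention described above. IterationTimer is a hypothetical name. The model and inference settings mirror the Model metadata example.

```julia
using RxInfer
using RxInfer.ReactiveMP

@model function gaussian_model(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Hypothetical handler that times each iteration
struct IterationTimer
    started_at::Base.RefValue{UInt64}   # time_ns() at the most recent before_iteration event
    durations_ns::Vector{UInt64}        # one entry per completed iteration
end
IterationTimer() = IterationTimer(Ref(UInt64(0)), UInt64[])

# Catch-all: ignore every other event
ReactiveMP.handle_event(::IterationTimer, ::ReactiveMP.Event) = nothing

function ReactiveMP.handle_event(timer::IterationTimer, ::ReactiveMP.Event{:before_iteration})
    timer.started_at[] = time_ns()
    return nothing
end

function ReactiveMP.handle_event(timer::IterationTimer, ::ReactiveMP.Event{:after_iteration})
    push!(timer.durations_ns, time_ns() - timer.started_at[])
    return nothing
end

timer = IterationTimer()

result = infer(
    model = gaussian_model(),
    data  = (y = randn(50),),
    constraints = MeanField(),
    initialization = @initialization(begin
        q(μ) = vague(NormalMeanVariance)
        q(τ) = vague(GammaShapeRate)
    end),
    iterations = 5,
    callbacks = timer,
)

println("Per-iteration time (ms): ", round.(timer.durations_ns ./ 1e6; digits = 3))
```

Keeping the start time in a Ref lets the struct itself stay immutable while still carrying state between the two events.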

Model metadata

The ProbabilisticModel structure contains a metadata dictionary (Dict{Any, Any}) that callbacks can use to store arbitrary information during inference. This is accessible from the inference result via result.model.metadata.

For example, you can track the history of marginal updates during variational inference:

@model function gaussian_model(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

struct MarginalHistoryCollector end

# Catch-all: ignore every event except marginal updates
ReactiveMP.handle_event(::MarginalHistoryCollector, ::ReactiveMP.Event) = nothing

function ReactiveMP.handle_event(::MarginalHistoryCollector, event::OnMarginalUpdateEvent)
    # Create the history vector on first use, then append the new marginal
    history = get!(() -> [], event.model.metadata, :marginal_history)
    push!(history, (variable = event.variable_name, value = event.update))
end

result = infer(
    model = gaussian_model(),
    data  = (y = randn(50),),
    constraints = MeanField(),
    initialization = @initialization(begin
        q(μ) = vague(NormalMeanVariance)
        q(τ) = vague(GammaShapeRate)
    end),
    iterations = 5,
    callbacks = MarginalHistoryCollector(),
)

# Access the collected marginal history
history = result.model.metadata[:marginal_history]
println("Collected ", length(history), " marginal updates across all iterations")
println("Variables updated: ", unique(map(h -> h.variable, history)))
Collected 10 marginal updates across all iterations
Variables updated: [:τ, :μ]
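Storing aggregated state instead of the full history works the same way. In the sketch below, the handler keeps a per-variable count of marginal updates in model.metadata[:update_counts]. UpdateCounter is a hypothetical name. With the same model and settings as above, the counts should add up to the 10 updates reported in the output above.

```julia
using RxInfer
using RxInfer.ReactiveMP

@model function gaussian_model(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Hypothetical handler that counts marginal updates per variable
struct UpdateCounter end

# Catch-all: ignore every event except marginal updates
ReactiveMP.handle_event(::UpdateCounter, ::ReactiveMP.Event) = nothing

function ReactiveMP.handle_event(::UpdateCounter, event::OnMarginalUpdateEvent)
    # Create the counts dictionary on first use, then bump the entry for this variable
    counts = get!(() -> Dict{Symbol, Int}(), event.model.metadata, :update_counts)
    counts[event.variable_name] = get(counts, event.variable_name, 0) + 1
    return nothing
end

result = infer(
    model = gaussian_model(),
    data  = (y = randn(50),),
    constraints = MeanField(),
    initialization = @initialization(begin
        q(μ) = vague(NormalMeanVariance)
        q(τ) = vague(GammaShapeRate)
    end),
    iterations = 5,
    callbacks = UpdateCounter(),
)

counts = result.model.metadata[:update_counts]
println(counts)
```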

Available events

Callbacks can listen to events from two layers: RxInfer-level events from the inference lifecycle, and ReactiveMP-level events from the message passing engine itself.

RxInfer events

These events are fired by the infer function during the inference lifecycle. Each event is a concrete struct subtyping ReactiveMP.Event{E} with named fields.

Common to batch and streamline inference

Batch inference only

For more details on batch inference, see Static inference.

RxInfer.OnMarginalUpdateEvent (Type)
OnMarginalUpdateEvent{M, U} <: ReactiveMP.Event{:on_marginal_update}

Fires each time a marginal posterior for a variable is updated during inference.

Fields

  • model::M: the ProbabilisticModel instance
  • variable_name::Symbol: the name of the variable whose marginal was updated
  • update::U: the updated marginal value

See also: Callbacks

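Event types are ordinary Julia structs. You can therefore inspect their fields and their place in the Event{E} hierarchy interactively, which helps when the documentation for an event is not at hand. Here is a minimal sketch that uses standard reflection (fieldnames and <:) on OnMarginalUpdateEvent:

```julia
using RxInfer
using RxInfer.ReactiveMP

# List the fields carried by the event (fieldnames also accepts the unparameterized type)
println(fieldnames(RxInfer.OnMarginalUpdateEvent))

# Check where the event sits in the Event{E} hierarchy
println(RxInfer.OnMarginalUpdateEvent <: ReactiveMP.Event{:on_marginal_update})
```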
Note

before_iteration and after_iteration events carry a mutable stop_iteration::Bool field (default false). Set event.stop_iteration = true from a callback to halt iterations early. See Early stopping for an example.
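Here is a minimal early-stopping sketch that assumes the same Gaussian model and variational setup as the Model metadata example. The request is for iterations = 10. The after_iteration callback records each iteration number and sets event.stop_iteration = true once a hypothetical threshold of three iterations is reached. In practice, the stopping rule would usually check a convergence criterion rather than a fixed count.

```julia
using RxInfer
using RxInfer.ReactiveMP

@model function gaussian_model(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

last_iteration = Ref(0)   # records the most recent completed iteration

result = infer(
    model = gaussian_model(),
    data  = (y = randn(50),),
    constraints = MeanField(),
    initialization = @initialization(begin
        q(μ) = vague(NormalMeanVariance)
        q(τ) = vague(GammaShapeRate)
    end),
    iterations = 10,
    callbacks = (
        after_iteration = (event) -> begin
            last_iteration[] = event.iteration
            # Hypothetical stopping rule: halt once iteration 3 has completed
            if event.iteration >= 3
                event.stop_iteration = true
            end
            nothing
        end,
    ),
)

println("Stopped after iteration ", last_iteration[])
```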

Streamline inference only

For more details on streamline inference, see Streamline inference.

ReactiveMP events

These lower-level events are fired by the ReactiveMP message passing engine during inference. They are available in both batch and streamline inference modes. Each event is a concrete struct subtyping ReactiveMP.Event{E} with named fields.

  • BeforeMessageRuleCallEvent / AfterMessageRuleCallEvent — fired around message rule computations
  • BeforeProductOfTwoMessagesEvent / AfterProductOfTwoMessagesEvent — fired around pairwise message products
  • BeforeProductOfMessagesEvent / AfterProductOfMessagesEvent — fired around folded message products
  • BeforeFormConstraintAppliedEvent / AfterFormConstraintAppliedEvent — fired around form constraint application
  • BeforeMarginalComputationEvent / AfterMarginalComputationEvent — fired around marginal computations

For detailed descriptions of these events and their fields, refer to the official documentation of ReactiveMP.
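ReactiveMP-level events are handled exactly like RxInfer-level events. The sketch below counts how many message update rules are evaluated while inferring the coin model from earlier. It dispatches on ReactiveMP.Event{:before_message_rule_call}, the event name given under "Dispatching by event name", so it does not depend on BeforeMessageRuleCallEvent being exported. RuleCallCounter is a hypothetical name. The exact count depends on the model and on how the engine schedules updates, so only the fact that it is positive is meaningful here.

```julia
using RxInfer
using RxInfer.ReactiveMP

@model function coin_model(y)
    θ ~ Beta(1, 1)
    y .~ Bernoulli(θ)
end

# Hypothetical handler that counts evaluated message rules
struct RuleCallCounter
    calls::Base.RefValue{Int}
end

# Catch-all: ignore every other event
ReactiveMP.handle_event(::RuleCallCounter, ::ReactiveMP.Event) = nothing

# Dispatch by event name, so the concrete struct does not need to be exported
function ReactiveMP.handle_event(counter::RuleCallCounter, ::ReactiveMP.Event{:before_message_rule_call})
    counter.calls[] += 1
    return nothing
end

counter = RuleCallCounter(Ref(0))

result = infer(
    model = coin_model(),
    data  = (y = [1, 0, 1, 1, 0],),
    callbacks = counter,
)

println("Message rule calls: ", counter.calls[])
```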

Migration from positional-argument callbacks

Breaking change

Previous versions of RxInfer passed callback arguments positionally:

# OLD (no longer works)
callbacks = (
    after_iteration = (model, iteration) -> println(iteration),
    on_marginal_update = (model, name, update) -> println(name),
)

The new system uses event structs with named fields. Each callback now receives a single event object:

# NEW
callbacks = (
    after_iteration = (event) -> println(event.iteration),
    on_marginal_update = (event) -> println(event.variable_name),
)
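If existing code already has callbacks in the positional style, one low-effort migration path is to keep those functions and wrap each one in a small adapter that unpacks the event's named fields. The legacy_* functions and the messages log below are hypothetical. The field names (model, iteration, variable_name, update) are the ones documented for these events. NamedTuples stand in for the event structs so that the sketch runs without a model.

```julia
# Hypothetical log so the effect of the legacy callbacks can be inspected
messages = String[]

# Legacy positional-style callbacks (hypothetical pre-existing code)
legacy_after_iteration    = (model, iteration) -> push!(messages, "iteration $iteration")
legacy_on_marginal_update = (model, name, update) -> push!(messages, "updated $name")

# Adapters: read the event's named fields and forward them positionally
callbacks = (
    after_iteration    = (event) -> legacy_after_iteration(event.model, event.iteration),
    on_marginal_update = (event) -> legacy_on_marginal_update(event.model, event.variable_name, event.update),
)

# Quick check without a model: NamedTuples stand in for the event structs
callbacks.after_iteration((model = nothing, iteration = 2))
callbacks.on_marginal_update((model = nothing, variable_name = :θ, update = 0.5))
println(messages)   # ["iteration 2", "updated θ"]
```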

For custom callback structures, implement handle_event methods that dispatch on the concrete event type:

struct MyHandler end

# Dispatch on specific events; fields such as event.model are accessed by name
ReactiveMP.handle_event(::MyHandler, event::BeforeInferenceEvent) = println("Starting inference")
# Catch-all for events you don't care about
ReactiveMP.handle_event(::MyHandler, ::ReactiveMP.Event) = nothing

The migration is straightforward: replace positional arguments with named field access on the event object. The field names of every event struct are documented in the sections above.

Built-in callback handlers

RxInfer provides the following built-in callback structures: