Callbacks
The infer function and the underlying reactive message passing engine both have their own lifecycle, consisting of multiple steps. By supplying callbacks, users can inject custom logic at specific moments during the inference procedure — for example, for debugging, performance analysis, or early stopping.
Event-based callback system
All callbacks in RxInfer use an event-based dispatch system built on ReactiveMP.Event{E}. Each callback event is a concrete struct that carries all relevant data as named fields. This makes callbacks self-documenting and extensible.
For example, an AfterIterationEvent exposes, among other fields, model and iteration, which a callback reads by name:
# NamedTuple callback: receives the event struct
callbacks = (
    after_iteration = (event) -> println("Iteration ", event.iteration, " done"),
)
Callback types
The callbacks keyword argument of the infer function accepts three types of callback handlers:
NamedTuple
The simplest way to define callbacks is via a NamedTuple, where keys correspond to event names and values are functions that receive a single event object:
using RxInfer
using RxInfer.ReactiveMP
@model function coin_model(y)
    θ ~ Beta(1, 1)
    y .~ Bernoulli(θ)
end
result = infer(
    model = coin_model(),
    data = (y = [1, 0, 1, 1, 0],),
    callbacks = (
        before_inference = (event) -> begin
            println("Starting inference on model: ", typeof(event.model))
        end,
        after_inference = (event) -> begin
            println("Inference completed")
        end,
    )
)
Starting inference on model: ProbabilisticModel{GraphPPL.Model{MetaGraphsNext.MetaGraph{Int64, Graphs.SimpleGraphs.SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#MetaGraph##6#MetaGraph##7", Float64}, GraphPPL.PluginsCollection{Tuple{GraphPPL.VariationalConstraintsPlugin{GraphPPL.NoConstraints}, GraphPPL.MetaPlugin{GraphPPL.MetaSpecification}, RxInfer.InitializationPlugin{RxInfer.NoInit}, RxInfer.ReactiveMPInferencePlugin{RxInfer.ReactiveMPInferenceOptions{Nothing, Nothing, Nothing, @NamedTuple{before_inference::Main.var"Main".var"#14#15", after_inference::Main.var"Main".var"#16#17"}}}}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}, String}}
Inference completed
When defining a NamedTuple with a single entry, make sure to include a trailing comma. In Julia, (key = value) is parsed as a variable assignment, not a NamedTuple. Use (key = value,) (with a trailing comma) instead. ReactiveMP will raise a helpful error if this mistake is detected.
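Conceptually, a NamedTuple handler is a lookup from event name to function. The plain-Julia sketch below models that idea without RxInfer so that it runs on its own; ToyEvent, ToyAfterIterationEvent, ToyBeforeInferenceEvent and toy_dispatch are hypothetical stand-ins, not ReactiveMP's actual types or dispatch code. The last lines illustrate the trailing-comma pitfall from the note above.

```julia
# Hypothetical stand-ins for illustration only; not ReactiveMP's real implementation.
abstract type ToyEvent{E} end

struct ToyAfterIterationEvent <: ToyEvent{:after_iteration}
    iteration::Int
end

struct ToyBeforeInferenceEvent <: ToyEvent{:before_inference} end

# Call the function stored under the event's name E, or do nothing if no key matches.
function toy_dispatch(callbacks::NamedTuple, event::ToyEvent{E}) where {E}
    return haskey(callbacks, E) ? callbacks[E](event) : nothing
end

seen = Int[]
callbacks = (after_iteration = (event) -> push!(seen, event.iteration),)

for i in 1:3
    toy_dispatch(callbacks, ToyAfterIterationEvent(i))
end
unhandled = toy_dispatch(callbacks, ToyBeforeInferenceEvent())  # no matching key

# Single-entry pitfall: without a trailing comma the parentheses only group an assignment.
not_a_tuple = (after_iteration = 1)   # assigns after_iteration = 1; the value is the Int 1
a_tuple = (after_iteration = 1,)      # a NamedTuple with one field

println(seen)                         # [1, 2, 3]
println(unhandled === nothing)        # true
println(not_a_tuple isa NamedTuple, " ", a_tuple isa NamedTuple)  # false true
```

Events whose name has no matching key are simply ignored in this sketch, which mirrors how a NamedTuple only needs entries for the events you care about.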
Dict
A Dict{Symbol} works the same way as a NamedTuple, but allows dynamic construction of callbacks:
my_callbacks = Dict(
    :before_inference => (event) -> println("Starting inference"),
    :after_inference => (event) -> println("Inference completed"),
)
result = infer(
    model = coin_model(),
    data = (y = [1, 0, 1, 1, 0],),
    callbacks = my_callbacks,
)
Starting inference
Inference completed
Custom callback structures
For more advanced use cases, you can pass any custom structure as a callback handler. The structure must implement ReactiveMP.handle_event methods for the event types it wants to handle:
struct MyCallbackHandler
    log::Vector{String}
end
# Catch-all: ignore events you don't care about
ReactiveMP.handle_event(::MyCallbackHandler, ::ReactiveMP.Event) = nothing
# Handle specific events by dispatching on the concrete event type
function ReactiveMP.handle_event(handler::MyCallbackHandler, event::BeforeInferenceEvent)
    push!(handler.log, "inference started")
end
function ReactiveMP.handle_event(handler::MyCallbackHandler, event::AfterInferenceEvent)
    push!(handler.log, "inference completed")
end
handler = MyCallbackHandler(String[])
result = infer(
    model = coin_model(),
    data = (y = [1, 0, 1, 1, 0],),
    callbacks = handler,
)
println(handler.log)
["inference started", "inference completed"]
Dispatching by event name with Event{:name}
Since every event struct is a subtype of ReactiveMP.Event{E} where E is a Symbol, you can also dispatch on ReactiveMP.Event{:event_name} instead of the concrete type name. This is equivalent — you still have access to all the same fields — and can be more convenient when the concrete type is not exported. For example, ReactiveMP-level events like BeforeMessageRuleCallEvent are not exported by default, but you can always dispatch on ReactiveMP.Event{:before_message_rule_call}:
struct MyEventNameHandler
    log::Vector{String}
end
# Catch-all
ReactiveMP.handle_event(::MyEventNameHandler, ::ReactiveMP.Event) = nothing
# Dispatch using Event{:name}: no need to know the concrete struct name
function ReactiveMP.handle_event(handler::MyEventNameHandler, event::ReactiveMP.Event{:before_inference})
    push!(handler.log, "inference started (via Event name)")
end
function ReactiveMP.handle_event(handler::MyEventNameHandler, event::ReactiveMP.Event{:after_inference})
    push!(handler.log, "inference completed (via Event name)")
end
handler_by_name = MyEventNameHandler(String[])
result = infer(
    model = coin_model(),
    data = (y = [1, 0, 1, 1, 0],),
    callbacks = handler_by_name,
)
println(handler_by_name.log)
["inference started (via Event name)", "inference completed (via Event name)"]
Both approaches are fully interchangeable: use whichever is more convenient for your use case.
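The interchangeability follows from the type hierarchy listed in the event reference below: each RxInfer event struct is declared as a subtype of ReactiveMP.Event{:name}. The short check below, assuming the event types are reachable as RxInfer.BeforeInferenceEvent and so on (as their reference entries name them), confirms this directly.

```julia
using RxInfer
using RxInfer.ReactiveMP

# Each concrete event type is a subtype of ReactiveMP.Event{:name}, so methods
# written against either signature accept the same event objects.
println(RxInfer.BeforeInferenceEvent <: ReactiveMP.Event{:before_inference})   # true
println(RxInfer.AfterInferenceEvent <: ReactiveMP.Event{:after_inference})     # true
println(RxInfer.AfterIterationEvent <: ReactiveMP.Event{:after_iteration})     # true

# Different event names are unrelated types, so a method for one never matches another
println(RxInfer.BeforeInferenceEvent <: ReactiveMP.Event{:after_inference})    # false
```

If a handler happens to define methods for both the concrete type and the matching Event{:name}, ordinary Julia dispatch selects the more specific one, which is the concrete type.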
Custom callback structures are useful when you need to:
- Maintain state across events (e.g., collecting timing information)
- Implement complex logic that spans multiple events
- Store information in the model's metadata dictionary for later access
RxInfer provides built-in callback structures such as RxInferBenchmarkCallbacks and StopEarlyIterationStrategy as examples of this pattern.
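As a sketch of the first use case, the handler below keeps state across paired events: it records a start time on each BeforeIterationEvent and computes the elapsed time on the matching AfterIterationEvent. IterationTimer is a hypothetical name, and the model and inference settings are reused from the Model metadata example further down this page. For routine measurements, the built-in RxInferBenchmarkCallbacks is likely the more convenient choice.

```julia
using RxInfer
using RxInfer.ReactiveMP

@model function gaussian_model(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Hypothetical handler (not part of RxInfer) measuring the wall-clock time of each iteration
struct IterationTimer
    started::Dict{Int, UInt64}     # iteration number => start timestamp in ns
    durations_ms::Vector{Float64}  # elapsed time per completed iteration, in ms
end
IterationTimer() = IterationTimer(Dict{Int, UInt64}(), Float64[])

# Ignore every event that is not handled explicitly below
ReactiveMP.handle_event(::IterationTimer, ::ReactiveMP.Event) = nothing

function ReactiveMP.handle_event(timer::IterationTimer, event::RxInfer.BeforeIterationEvent)
    timer.started[event.iteration] = time_ns()
    return nothing
end

function ReactiveMP.handle_event(timer::IterationTimer, event::RxInfer.AfterIterationEvent)
    t0 = pop!(timer.started, event.iteration)  # start time recorded by the matching before-event
    push!(timer.durations_ms, (time_ns() - t0) / 1e6)
    return nothing
end

timer = IterationTimer()
result = infer(
    model = gaussian_model(),
    data = (y = randn(50),),
    constraints = MeanField(),
    initialization = @initialization(begin
        q(μ) = vague(NormalMeanVariance)
        q(τ) = vague(GammaShapeRate)
    end),
    iterations = 5,
    callbacks = timer,
)
println("Per-iteration time (ms): ", round.(timer.durations_ms, digits = 3))
```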
Model metadata
The ProbabilisticModel structure contains a metadata dictionary (Dict{Any, Any}) that callbacks can use to store arbitrary information during inference. This is accessible from the inference result via result.model.metadata.
For example, you can track the history of marginal updates during variational inference:
@model function gaussian_model(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end
struct MarginalHistoryCollector end
ReactiveMP.handle_event(::MarginalHistoryCollector, ::ReactiveMP.Event) = nothing
function ReactiveMP.handle_event(::MarginalHistoryCollector, event::OnMarginalUpdateEvent)
    # Create the history vector on first use, then append each update to it
    history = get!(() -> [], event.model.metadata, :marginal_history)
    push!(history, (variable_name = event.variable_name, value = event.update))
end
result = infer(
    model = gaussian_model(),
    data = (y = randn(50),),
    constraints = MeanField(),
    initialization = @initialization(begin
        q(μ) = vague(NormalMeanVariance)
        q(τ) = vague(GammaShapeRate)
    end),
    iterations = 5,
    callbacks = MarginalHistoryCollector(),
)
# Access the collected marginal history
history = result.model.metadata[:marginal_history]
println("Collected ", length(history), " marginal updates across all iterations")
println("Variables updated: ", unique(map(h -> h.variable_name, history)))
Collected 10 marginal updates across all iterations
Variables updated: [:τ, :μ]
Available events
Callbacks can listen to events from two layers: RxInfer-level events from the inference lifecycle, and ReactiveMP-level events from the message passing engine itself.
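A quick way to see which events fire for a given model is a handler with a single method that dispatches on ReactiveMP.Event{E} and records the type parameter E. EventNameLogger is a hypothetical name used only for this sketch; the exact list printed depends on the model and inference settings.

```julia
using RxInfer
using RxInfer.ReactiveMP

@model function coin_model(y)
    θ ~ Beta(1, 1)
    y .~ Bernoulli(θ)
end

# Hypothetical handler that records the name of every event it receives
struct EventNameLogger
    names::Vector{Symbol}
end

# The event name is the type parameter E of ReactiveMP.Event{E}, so this single
# method receives events from both the RxInfer and the ReactiveMP layer.
function ReactiveMP.handle_event(logger::EventNameLogger, ::ReactiveMP.Event{E}) where {E}
    push!(logger.names, E)
    return nothing
end

logger = EventNameLogger(Symbol[])
result = infer(
    model = coin_model(),
    data = (y = [1, 0, 1, 1, 0],),
    callbacks = logger,
)
# Distinct event names, in the order they first fired
println(unique(logger.names))
```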
RxInfer events
These events are fired by the infer function during the inference lifecycle. Each event is a concrete struct subtyping ReactiveMP.Event{E} with named fields.
Common to batch and streamline inference
RxInfer.BeforeModelCreationEvent — Type
BeforeModelCreationEvent{S} <: ReactiveMP.Event{:before_model_creation}
Fires right before the probabilistic model is created in the infer function.
Fields
- span_id: an identifier shared with the corresponding AfterModelCreationEvent
See also: AfterModelCreationEvent, Callbacks
RxInfer.AfterModelCreationEvent — Type
AfterModelCreationEvent{M, S} <: ReactiveMP.Event{:after_model_creation}
Fires right after the probabilistic model is created in the infer function.
Fields
- model::M: the created ProbabilisticModel instance
- span_id: an identifier shared with the corresponding BeforeModelCreationEvent
See also: BeforeModelCreationEvent, Callbacks
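A sketch of how span_id can pair the two halves of a span, here to time model creation with NamedTuple callbacks. It assumes that the shared span_id values compare equal with ==; marks is only an illustrative variable name.

```julia
using RxInfer

@model function coin_model(y)
    θ ~ Beta(1, 1)
    y .~ Bernoulli(θ)
end

# Store (span_id, timestamp) for each side of the model-creation span
marks = Dict{Symbol, Any}()

result = infer(
    model = coin_model(),
    data = (y = [1, 0, 1, 1, 0],),
    callbacks = (
        before_model_creation = (event) -> (marks[:before] = (event.span_id, time_ns())),
        after_model_creation = (event) -> (marks[:after] = (event.span_id, time_ns())),
    ),
)

(before_id, t0) = marks[:before]
(after_id, t1) = marks[:after]
println("span ids match: ", before_id == after_id)
println("model creation took ", (t1 - t0) / 1e6, " ms")
```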
Batch inference only
For more details on batch inference, see Static inference.
RxInfer.OnMarginalUpdateEvent — Type
OnMarginalUpdateEvent{M, U} <: ReactiveMP.Event{:on_marginal_update}
Fires each time the marginal posterior of a variable is updated during inference.
Fields
- model::M: the ProbabilisticModel instance
- variable_name::Symbol: the name of the variable whose marginal was updated
- update::U: the updated marginal value
See also: Callbacks
RxInfer.BeforeInferenceEvent — Type
BeforeInferenceEvent{M, S} <: ReactiveMP.Event{:before_inference}
Fires right before the inference procedure starts (after model creation and subscription setup).
Fields
- model::M: the ProbabilisticModel instance
- span_id: an identifier shared with the corresponding AfterInferenceEvent
See also: AfterInferenceEvent, Callbacks
RxInfer.AfterInferenceEvent — Type
AfterInferenceEvent{M, S} <: ReactiveMP.Event{:after_inference}
Fires right after the inference procedure completes.
Fields
- model::M: the ProbabilisticModel instance
- span_id: an identifier shared with the corresponding BeforeInferenceEvent
See also: BeforeInferenceEvent, Callbacks
RxInfer.BeforeIterationEvent — Type
BeforeIterationEvent{M, S} <: ReactiveMP.Event{:before_iteration}
Fires right before each variational iteration.
Fields
- model::M: the ProbabilisticModel instance
- iteration::Int: the current iteration number
- stop_iteration::Bool: set to true from a callback to halt iterations early (default: false)
- span_id: an identifier shared with the corresponding AfterIterationEvent
See also: AfterIterationEvent, Callbacks
RxInfer.AfterIterationEvent — Type
AfterIterationEvent{M, S} <: ReactiveMP.Event{:after_iteration}
Fires right after each variational iteration.
Fields
- model::M: the ProbabilisticModel instance
- iteration::Int: the current iteration number
- stop_iteration::Bool: set to true from a callback to halt iterations early (default: false)
- span_id: an identifier shared with the corresponding BeforeIterationEvent
See also: BeforeIterationEvent, Callbacks
before_iteration and after_iteration events carry a mutable stop_iteration::Bool field (default false). Set event.stop_iteration = true from a callback to halt iterations early. See Early stopping for an example.
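As a minimal sketch, the callback below requests a stop once a fixed iteration is reached. The threshold of 2 is arbitrary and stands in for a real convergence check (StopEarlyIterationStrategy provides one based on free energy); the model and settings are borrowed from the Model metadata example above.

```julia
using RxInfer

@model function gaussian_model(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Record the iteration number of every after_iteration event
completed = Int[]

result = infer(
    model = gaussian_model(),
    data = (y = randn(50),),
    constraints = MeanField(),
    initialization = @initialization(begin
        q(μ) = vague(NormalMeanVariance)
        q(τ) = vague(GammaShapeRate)
    end),
    iterations = 10,
    callbacks = (
        after_iteration = (event) -> begin
            push!(completed, event.iteration)
            # Illustrative stopping rule: halt once iteration 2 has finished
            if event.iteration >= 2
                event.stop_iteration = true
            end
        end,
    ),
)
println("Iterations completed: ", completed)
```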
RxInfer.BeforeDataUpdateEvent — Type
BeforeDataUpdateEvent{M, D, S} <: ReactiveMP.Event{:before_data_update}
Fires right before updating data variables in each iteration.
Fields
- model::M: the ProbabilisticModel instance
- data::D: the data being used for the update
- span_id: an identifier shared with the corresponding AfterDataUpdateEvent
See also: AfterDataUpdateEvent, Callbacks
RxInfer.AfterDataUpdateEvent — Type
AfterDataUpdateEvent{M, D, S} <: ReactiveMP.Event{:after_data_update}
Fires right after updating data variables in each iteration.
Fields
- model::M: the ProbabilisticModel instance
- data::D: the data that was used for the update
- span_id: an identifier shared with the corresponding BeforeDataUpdateEvent
See also: BeforeDataUpdateEvent, Callbacks
Streamline inference only
For more details on streamline inference, see Streamline inference.
RxInfer.BeforeAutostartEvent — Type
BeforeAutostartEvent{E, S} <: ReactiveMP.Event{:before_autostart}
Fires right before RxInfer.start() is called on the streaming inference engine (when autostart = true).
Fields
- engine::E: the RxInferenceEngine instance
- span_id: an identifier shared with the corresponding AfterAutostartEvent
See also: AfterAutostartEvent, Callbacks
RxInfer.AfterAutostartEvent — Type
AfterAutostartEvent{E, S} <: ReactiveMP.Event{:after_autostart}
Fires right after RxInfer.start() is called on the streaming inference engine (when autostart = true).
Fields
- engine::E: the RxInferenceEngine instance
- span_id: an identifier shared with the corresponding BeforeAutostartEvent
See also: BeforeAutostartEvent, Callbacks
ReactiveMP events
These lower-level events are fired by the ReactiveMP message passing engine during inference. They are available in both batch and streamline inference modes. Each event is a concrete struct subtyping ReactiveMP.Event{E} with named fields:
- BeforeMessageRuleCallEvent / AfterMessageRuleCallEvent: fired around message rule computations
- BeforeProductOfTwoMessagesEvent / AfterProductOfTwoMessagesEvent: fired around pairwise message products
- BeforeProductOfMessagesEvent / AfterProductOfMessagesEvent: fired around folded message products
- BeforeFormConstraintAppliedEvent / AfterFormConstraintAppliedEvent: fired around form constraint application
- BeforeMarginalComputationEvent / AfterMarginalComputationEvent: fired around marginal computations
For detailed descriptions of these events and their fields, refer to the official documentation of ReactiveMP.
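For example, counting message rule invocations only needs the event name, not its fields. RuleCallCounter is a hypothetical name; the handler dispatches on ReactiveMP.Event{:before_message_rule_call}, the same event name used in the dispatch-by-name section earlier on this page, because the concrete type is not exported.

```julia
using RxInfer
using RxInfer.ReactiveMP

@model function coin_model(y)
    θ ~ Beta(1, 1)
    y .~ Bernoulli(θ)
end

# Hypothetical handler counting how many message update rules were invoked
mutable struct RuleCallCounter
    count::Int
end

# Ignore all other events
ReactiveMP.handle_event(::RuleCallCounter, ::ReactiveMP.Event) = nothing

# Dispatch on the event name, since the concrete event type is not exported
function ReactiveMP.handle_event(counter::RuleCallCounter, ::ReactiveMP.Event{:before_message_rule_call})
    counter.count += 1
    return nothing
end

counter = RuleCallCounter(0)
result = infer(
    model = coin_model(),
    data = (y = [1, 0, 1, 1, 0],),
    callbacks = counter,
)
println("Message rule calls: ", counter.count)
```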
Migration from positional-argument callbacks
Previous versions of RxInfer passed callback arguments positionally:
# OLD (no longer works)
callbacks = (
    after_iteration = (model, iteration) -> println(iteration),
    on_marginal_update = (model, name, update) -> println(name),
)
The new system uses event structs with named fields. Each callback now receives a single event object:
# NEW
callbacks = (
    after_iteration = (event) -> println(event.iteration),
    on_marginal_update = (event) -> println(event.variable_name),
)
For custom callback structures, implement handle_event methods that dispatch on the concrete event type:
struct MyHandler end
# Dispatch on specific events; read data from the event's named fields
ReactiveMP.handle_event(::MyHandler, event::BeforeInferenceEvent) = println("Inference started for ", nameof(typeof(event.model)))
# Catch-all for events you don't care about
ReactiveMP.handle_event(::MyHandler, ::ReactiveMP.Event) = nothing
The migration is straightforward: replace positional arguments with named field access on the event object. The event structs are fully documented with their field names; see the sections above.
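When a code base has many callbacks written against the old signatures, thin adapters can bridge them while they are being ported. This is a sketch of that approach, not an RxInfer utility; old_after_iteration and old_on_marginal_update are hypothetical placeholders for existing user functions, and the model is borrowed from the Model metadata example.

```julia
using RxInfer

@model function gaussian_model(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Existing functions written for the old positional-argument signatures
messages = String[]
old_after_iteration = (model, iteration) -> push!(messages, "iteration $(iteration)")
old_on_marginal_update = (model, name, update) -> push!(messages, "updated $(name)")

# Adapters: read the named event fields and forward them positionally
callbacks = (
    after_iteration = (event) -> old_after_iteration(event.model, event.iteration),
    on_marginal_update = (event) -> old_on_marginal_update(event.model, event.variable_name, event.update),
)

result = infer(
    model = gaussian_model(),
    data = (y = randn(50),),
    constraints = MeanField(),
    initialization = @initialization(begin
        q(μ) = vague(NormalMeanVariance)
        q(τ) = vague(GammaShapeRate)
    end),
    iterations = 3,
    callbacks = callbacks,
)
println(messages)
```

Once every old function has been rewritten to take the event directly, the adapters can be dropped.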
Built-in callback handlers
RxInfer provides the following built-in callback structures:
- RxInferBenchmarkCallbacks: collects timing statistics across inference runs. See Benchmark callbacks.
- RxInferTraceCallbacks: records all callback events for debugging and inspection. See Trace callbacks.
- StopEarlyIterationStrategy: stops variational iterations early based on free energy convergence. See Early stopping.