Debugging
Debugging inference in RxInfer can be quite challenging, mostly due to the reactive nature of the inference, the undefined order of computations, the use of observables, and the generally hard-to-read stack traces in Julia. Below we discuss ways to find the problems in your model that prevent you from getting the results you want.
Getting Help from the Community
When you encounter issues that are difficult to debug, the RxInfer community is here to help. To get the most effective support:
Share Session Data: For complex issues, you can share your session data to help us understand exactly what's happening in your model. See Session Sharing to learn how.
Join Community Meetings: We discuss common issues and solutions in our regular community meetings. See Getting Help with Issues for more information.
Requesting a trace of messages
RxInfer can record the history of the computations that led to each message and marginal in the inference procedure. This history is attached to messages and marginals and is referred to as a Memory Addon. Below is an example that shows how to extract this history and use it to fix a bug.
Addons are a feature of ReactiveMP. Read more about implementing custom addons in the corresponding section of the ReactiveMP package documentation.
We show the application of the Memory Addon on the coin toss example from earlier in the documentation. We model the binary outcome $x$ (heads or tails) using a Bernoulli distribution, with a parameter $\theta$ that represents the probability of landing on heads. We place a Beta prior on $\theta$, with a known shape parameter $a$ and rate parameter $b$.
\[\theta \sim \mathrm{Beta}(a, b)\]
\[x_i \sim \mathrm{Bernoulli}(\theta)\]
where $x_i \in \{0, 1\}$ are the binary observations (heads = 1, tails = 0). This is the corresponding RxInfer model:
using RxInfer, Random, Plots
n = 4
θ_real = 0.3
dataset = float.(rand(Bernoulli(θ_real), n))
@model function coin_model(x)
    θ ~ Beta(4, huge)
    x .~ Bernoulli(θ)
end
result = infer(
    model = coin_model(),
    data  = (x = dataset, ),
)

Inference results:
Posteriors | available for (θ)
The model will run without errors. But when we plot the posterior distribution for $\theta$, something's wrong. The posterior seems to be a flat distribution:
rθ = range(0, 1, length = 1000)
plot(rθ, (rvar) -> pdf(result.posteriors[:θ], rvar), label="Inferred posterior")
vline!([θ_real], label="Real θ", title = "Inference results")

We can figure out what's wrong by tracing the computation of the posterior with the Memory Addon. To obtain the trace, we have to add addons = (AddonMemory(),) as an argument to the infer function. Note that the value of the addons keyword argument must be a tuple, because multiple addons can be activated at the same time. Here, the tuple has a single element.
result = infer(
    model = coin_model(),
    data  = (x = dataset, ),
    addons = (AddonMemory(),)
)

Inference results:
Posteriors | available for (θ)
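The trailing comma in addons = (AddonMemory(),) matters: in Julia, parentheses with a trailing comma build a one-element tuple, whereas parentheses alone only group an expression. The following minimal sketch illustrates the difference in plain Julia, with an integer standing in for the addon object:

```julia
# A stand-in value; any object behaves the same way.
value = 1

one_element_tuple = (value,)   # trailing comma creates a Tuple
just_parentheses  = (value)    # parentheses alone only group the expression

println(one_element_tuple isa Tuple)  # true
println(just_parentheses isa Tuple)   # false
println(length(one_element_tuple))    # 1
```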
Now we have access to the messages that led to the marginal posterior:
RxInfer.ReactiveMP.getaddons(result.posteriors[:θ])

(AddonMemory(Product memory:
Message mapping memory:
At the node: Beta
Towards interface: Val{:out}()
With local constraint: Marginalisation()
With addons: (AddonMemory(nothing),)
With input marginals on Val{(:a, :b)}() edges: (PointMass{Int64}(4), PointMass{TinyHugeNumbers.HugeNumber}(huge))
With the result: Beta{Float64}(α=4.0, β=1.0e12)
Message mapping memory:
At the node: Bernoulli
Towards interface: Val{:p}()
With local constraint: Marginalisation()
With addons: (AddonMemory(nothing),)
With input marginals on Val{(:out,)}() edges: (PointMass{Float64}(1.0),)
With the result: Beta{Float64}(α=2.0, β=1.0)
Message mapping memory:
At the node: Bernoulli
Towards interface: Val{:p}()
With local constraint: Marginalisation()
With addons: (AddonMemory(nothing),)
With input marginals on Val{(:out,)}() edges: (PointMass{Float64}(0.0),)
With the result: Beta{Float64}(α=1.0, β=2.0)
Message mapping memory:
At the node: Bernoulli
Towards interface: Val{:p}()
With local constraint: Marginalisation()
With addons: (AddonMemory(nothing),)
With input marginals on Val{(:out,)}() edges: (PointMass{Float64}(0.0),)
With the result: Beta{Float64}(α=1.0, β=2.0)
Message mapping memory:
At the node: Bernoulli
Towards interface: Val{:p}()
With local constraint: Marginalisation()
With addons: (AddonMemory(nothing),)
With input marginals on Val{(:out,)}() edges: (PointMass{Float64}(0.0),)
With the result: Beta{Float64}(α=1.0, β=2.0)
),)
The messages in the factor graph are marked in color. If you're interested in the mathematics behind these results, consider verifying them manually using the general equation for sum-product messages:
\[\underbrace{\overrightarrow{\mu}_{θ}(θ)}_{\substack{ \text{outgoing}\\ \text{message}}} = \sum_{x_1,\ldots,x_n} \underbrace{\overrightarrow{\mu}_{X_1}(x_1)\cdots \overrightarrow{\mu}_{X_n}(x_n)}_{\substack{\text{incoming} \\ \text{messages}}} \cdot \underbrace{f(θ,x_1,\ldots,x_n)}_{\substack{\text{node}\\ \text{function}}}\]
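For this conjugate model the manual check reduces to simple arithmetic, because the product of Beta densities is again proportional to a Beta density: Beta(a1, b1) · Beta(a2, b2) ∝ Beta(a1 + a2 − 1, b1 + b2 − 1). The sketch below multiplies the five messages listed in the trace above. The numbers are copied from that particular run (a different random dataset would produce different likelihood messages), and beta_product is a helper name introduced here for illustration, not part of RxInfer.

```julia
# (α, β) pairs of the Beta messages reported in the trace above.
prior_message = (4.0, 1.0e12)
likelihood_messages = [(2.0, 1.0), (1.0, 2.0), (1.0, 2.0), (1.0, 2.0)]

# Hypothetical helper: parameters of the (unnormalised) product of two Beta densities.
# Beta(a1, b1) * Beta(a2, b2) ∝ Beta(a1 + a2 - 1, b1 + b2 - 1)
beta_product(p, q) = (p[1] + q[1] - 1, p[2] + q[2] - 1)

# Fold all likelihood messages into the prior message.
posterior = foldl(beta_product, likelihood_messages; init = prior_message)

println("α = ", posterior[1], ", β = ", posterior[2])
```

The result has α = 5 and β = 1e12 + 3.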

Note that the posterior (yellow) has a rate parameter on the order of 1e12. Our plot failed because a Beta distribution with such a rate parameter cannot be accurately depicted using the range of $\theta$ we used in the code block above. So why does the posterior have this rate parameter?
All the messages from the observations (purple, green, pink, blue) have much smaller rate parameters. The message from the prior distribution (red), however, has an unusual rate parameter, namely 1e12. If we look back at the model, the parameter was set to huge, a constant from TinyHugeNumbers.jl that evaluates to 1e12 in Float64 arithmetic (it is an ordinary value, not a reserved keyword). Reducing the prior rate parameter will ensure the posterior has a reasonable rate parameter as well.
@model function coin_model(x)
    θ ~ Beta(4, 100)
    x .~ Bernoulli(θ)
end
result = infer(
    model = coin_model(),
    data  = (x = dataset, ),
)

Inference results:
Posteriors | available for (θ)
rθ = range(0, 1, length = 1000)
plot(rθ, (rvar) -> pdf(result.posteriors[:θ], rvar), fillalpha = 0.4, fill = 0, label="Inferred posterior")
vline!([θ_real], label="Real θ", title = "Inference results")

Now the posterior has a much more sensible shape, which confirms that we have identified the original issue correctly. We can run the model with more observations to get an even better posterior:
result = infer(
    model = coin_model(),
    data  = (x = float.(rand(Bernoulli(θ_real), 1000)), ),
)
rθ = range(0, 1, length = 1000)
plot(rθ, (rvar) -> pdf(result.posteriors[:θ], rvar), fillalpha = 0.4, fill = 0, label="Inferred posterior (1000 observations)")
vline!([θ_real], label="Real θ", title = "Inference results")

Using callbacks in the infer function
Another way to inspect the inference procedure is to use the callbacks or events from the infer function. Read more about callbacks in the documentation to the infer function. Here, we show an application of callbacks to a model containing both scalar and vectorized latent variables. We start with the model specification:
using RxInfer
@model function vectorized_model(y)
    local u
    θ ~ Normal(mean = 1.0, var = 1.0)
    for i in 1:2
        u[i] ~ Normal(mean = θ, var = 1.0) # Vectorized variables
    end
    y .~ Normal(mean = u, var = 1.0)
end

Next, let us define a synthetic dataset:
dataset = [1.0, 2.0]

Now, we can use the callbacks argument of the infer function to track the order in which posteriors are computed and their intermediate values for each variational iteration.
Since RxInfer passes a collection of posteriors when updating vectorized variables, we define two methods for on_marginal_update_callback:
- Scalar method: For individual variables (e.g., θ).
- Vector method: Specifically dispatched on AbstractVector to handle vectorized updates (e.g., u) using broadcasted operations like mean.().
# A callback that will be called every time before a variational iteration starts
function before_iteration_callback(model, iteration)
    println("--- Starting iteration ", iteration, " ---")
end
# A callback that will be called every time after a variational iteration finishes
function after_iteration_callback(model, iteration)
    println("--- Iteration ", iteration, " has been finished ---")
end
# A callback that will be called every time a posterior is updated
function on_marginal_update_callback(model, variable_name, posterior)
    println("Latent variable ", variable_name, " has been updated. Estimated mean is ", mean(posterior), " with standard deviation ", std(posterior))
end
# A callback dispatched specifically for vectorized latent variables
function on_marginal_update_callback(model, variable_name, posteriors::AbstractVector)
    println("Latent variable ", variable_name, " has been updated. Estimated mean is ", mean.(posteriors), " with standard deviation ", std.(posteriors))
end

on_marginal_update_callback (generic function with 2 methods)

After we have defined all callbacks of interest, we can call the infer function, passing them in the callbacks argument as a named tuple:
result = infer(
    model = vectorized_model(),
    data  = (y = dataset, ),
    iterations = 3,
    initialization = @initialization(q(θ) = Uniform(0, 1)),
    returnvars = KeepLast(),
    callbacks = (
        on_marginal_update = on_marginal_update_callback,
        before_iteration   = before_iteration_callback,
        after_iteration    = after_iteration_callback
    )
)

--- Starting iteration 1 ---
Latent variable θ has been updated. Estimated mean is 1.25 with standard deviation 0.7071067811865476
Latent variable u has been updated. Estimated mean is [1.125, 1.625] with standard deviation [0.7905694150420949, 0.7905694150420949]
--- Iteration 1 has been finished ---
--- Starting iteration 2 ---
Latent variable θ has been updated. Estimated mean is 1.25 with standard deviation 0.7071067811865476
Latent variable u has been updated. Estimated mean is [1.125, 1.625] with standard deviation [0.7905694150420949, 0.7905694150420949]
--- Iteration 2 has been finished ---
--- Starting iteration 3 ---
Latent variable θ has been updated. Estimated mean is 1.25 with standard deviation 0.7071067811865476
Latent variable u has been updated. Estimated mean is [1.125, 1.625] with standard deviation [0.7905694150420949, 0.7905694150420949]
--- Iteration 3 has been finished ---

We can see that the callbacks were executed at every variational iteration and handled both the scalar θ and the vector u correctly.
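Because every factor in this model is Gaussian, the posterior over θ can also be computed by hand, which gives an independent check on the values printed by the callbacks. Marginalising u[i] yields y[i] ~ Normal(θ, 2), and combining the two observations with the Normal(1, 1) prior is a precision-weighted average. The sketch below uses only plain arithmetic; all variable names are illustrative.

```julia
# Observations and prior, copied from the model and dataset above.
y = [1.0, 2.0]
prior_mean, prior_var = 1.0, 1.0

# Variance of y[i] given θ after marginalising u[i]:
# var(u[i] | θ) + var(y[i] | u[i]) = 1 + 1
likelihood_var = 2.0

# Precisions add; the posterior mean is the precision-weighted average.
posterior_precision = 1 / prior_var + length(y) / likelihood_var
posterior_mean = (prior_mean / prior_var + sum(y) / likelihood_var) / posterior_precision
posterior_std = sqrt(1 / posterior_precision)

println("mean = ", posterior_mean, ", std = ", posterior_std)
```

This gives a mean of 1.25 and a standard deviation of sqrt(0.5) ≈ 0.7071, matching the callback output for θ.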
println("Estimated mean u[1]: ", mean(result.posteriors[:u][1]))
println("Estimated mean u[2]: ", mean(result.posteriors[:u][2]))
println("Estimated mean θ: ", mean(result.posteriors[:θ]))

Estimated mean u[1]: 1.125
Estimated mean u[2]: 1.625
Estimated mean θ: 1.25

Using LoggerPipelineStage
The ReactiveMP inference engine allows attaching extra computations to the default computational pipeline of message passing. Read more about pipelines in the corresponding section of the ReactiveMP documentation. Here we show how to use LoggerPipelineStage to trace the order of message passing updates for debugging purposes. We start with the model specification:
using RxInfer
@model function iid_normal_with_pipeline(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    γ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = γ) where { pipeline = LoggerPipelineStage() }
end

Next, let us define a synthetic dataset:
# We use fewer data points in the dataset to reduce the amount of text printed
# during the inference
dataset = rand(NormalMeanPrecision(3.1415, 30.0), 5)

Now, we can call the infer function. We combine the pipeline logger stage with the callbacks, which were introduced in the previous section:
init = @initialization begin
    q(μ) = vague(NormalMeanVariance)
end
result = infer(
    model = iid_normal_with_pipeline(),
    data  = (y = dataset, ),
    constraints = MeanField(),
    iterations = 5,
    initialization = init,
    returnvars = KeepLast(),
    callbacks = (
        on_marginal_update = on_marginal_update_callback,
        before_iteration   = before_iteration_callback,
        after_iteration    = after_iteration_callback
    )
)

--- Starting iteration 1 ---
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable γ has been updated. Estimated mean is 1.3999999999854388e-12 with standard deviation 7.483314773470051e-13
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable μ has been updated. Estimated mean is 2.211427712662224e-9 with standard deviation 9.9999999965
--- Iteration 1 has been finished ---
--- Starting iteration 2 ---
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable γ has been updated. Estimated mean is 0.01268107438197917 with standard deviation 0.006778319376223164
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable μ has been updated. Estimated mean is 2.72880761846543 with standard deviation 3.690932291368008
--- Iteration 2 has been finished ---
--- Starting iteration 3 ---
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable γ has been updated. Estimated mean is 0.09839399669397612 with standard deviation 0.05259380350631808
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable μ has been updated. Estimated mean is 3.0962467644426255 with standard deviation 1.4114357905893622
--- Iteration 3 has been finished ---
--- Starting iteration 4 ---
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable γ has been updated. Estimated mean is 0.5793698248292362 with standard deviation 0.30968619782089085
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable μ has been updated. Estimated mean is 3.1483143841582715 with standard deviation 0.5865280067550619
--- Iteration 4 has been finished ---
--- Starting iteration 5 ---
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][τ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable γ has been updated. Estimated mean is 1.8314062601638696 with standard deviation 0.9789278230751689
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
[Log]: [NormalMeanPrecision][μ]: DeferredMessage([ use `as_message` to compute the message ])
Latent variable μ has been updated. Estimated mean is 3.1557362051368747 with standard deviation 0.33028256056859934
--- Iteration 5 has been finished ---

We can see the order of message update events. Note that ReactiveMP may decide to compute messages lazily, in which case the actual computation of a message's value is deferred until a later moment. In this case, LoggerPipelineStage reports a DeferredMessage.
Using RxInferBenchmarkCallbacks for Performance Analysis
RxInfer provides a built-in benchmarking callback structure called RxInferBenchmarkCallbacks that helps collect timing information during the inference procedure. This structure aggregates timing information across multiple runs, allowing you to track performance statistics (min/max/average/etc.) of your model's creation and inference procedure.
Here's how to use it:
using RxInfer
@model function iid_normal(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    γ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = γ)
end
init = @initialization begin
    q(μ) = vague(NormalMeanVariance)
end
# Create a benchmark callbacks instance to track performance
benchmark_callbacks = RxInferBenchmarkCallbacks()
# Run inference multiple times to gather statistics
for i in 1:3 # Usually you'd want more runs for better statistics
    infer(
        model = iid_normal(),
        data  = (y = dataset, ),
        constraints = MeanField(),
        iterations = 5,
        initialization = init,
        callbacks = benchmark_callbacks
    )
end

To display the statistics nicely, you may want to install the PrettyTables.jl package. It is not bundled with RxInfer by default, but, if installed manually, it makes the output more readable.
using PrettyTables
# Display the benchmark statistics in a nicely formatted table
PrettyTables.pretty_table(benchmark_callbacks)

RxInfer inference benchmark statistics: 3 evaluations
╭────────────────┬────────────┬────────────┬────────────┬────────────┬────────────╮
│ Operation      │        Min │        Max │       Mean │     Median │        Std │
├────────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ Model creation │ 282.167 μs │ 475.949 μs │ 350.698 μs │ 293.979 μs │ 108.631 μs │
│ Inference      │  46.146 μs │ 100.719 μs │  66.882 μs │  53.781 μs │  29.551 μs │
│ Iteration      │   3.858 μs │  41.638 μs │   8.786 μs │   4.558 μs │  10.021 μs │
╰────────────────┴────────────┴────────────┴────────────┴────────────┴────────────╯

The RxInferBenchmarkCallbacks structure collects timestamps at various stages of the inference process:
- Before and after model creation
- Before and after inference starts/ends
- Before and after each iteration
- Before and after autostart (for streaming inference)
RxInfer.RxInferBenchmarkCallbacks — Type
RxInferBenchmarkCallbacks(; capacity = RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)
A callback structure for collecting timing information during the inference procedure. This structure collects timestamps for various stages of the inference process and aggregates them across multiple runs, allowing you to track performance statistics (min/max/average/etc.) of your model's creation and inference procedure. The structure supports pretty printing by default, displaying timing statistics in a human-readable format.
The structure uses circular buffers with a default capacity of 1000 entries to store timestamps, which helps to limit memory usage in long-running applications. Use RxInferBenchmarkCallbacks(; capacity = N) to change the buffer capacity. See also RxInfer.get_benchmark_stats(callbacks).
Fields
- before_model_creation_ts: CircularBuffer of timestamps before model creation
- after_model_creation_ts: CircularBuffer of timestamps after model creation
- before_inference_ts: CircularBuffer of timestamps before inference starts
- after_inference_ts: CircularBuffer of timestamps after inference ends
- before_iteration_ts: CircularBuffer of vectors of timestamps before each iteration
- after_iteration_ts: CircularBuffer of vectors of timestamps after each iteration
- before_autostart_ts: CircularBuffer of timestamps before autostart
- after_autostart_ts: CircularBuffer of timestamps after autostart
Example
# Create a callbacks instance to track performance
callbacks = RxInferBenchmarkCallbacks()
# Run inference multiple times to gather statistics
for _ in 1:10
    infer(
        model = my_model(),
        data = my_data,
        callbacks = callbacks
    )
end
# Display the timing statistics (you need to install `PrettyTables.jl` to use `pretty_table` function)
using PrettyTables
PrettyTables.pretty_table(callbacks)

RxInfer.get_benchmark_stats — Function
get_benchmark_stats(callbacks::RxInferBenchmarkCallbacks)
Returns a matrix containing benchmark statistics for different operations in the inference process. The matrix contains the following columns:
- Operation name (String)
- Minimum time (Float64)
- Maximum time (Float64)
- Mean time (Float64)
- Median time (Float64)
- Standard deviation (Float64)
Each row represents a different operation (model creation, inference, iteration, autostart). Times are in nanoseconds.
RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY — Constant
DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY
The default capacity of the circular buffers used to store timestamps in the RxInferBenchmarkCallbacks structure.
By default, the RxInferBenchmarkCallbacks structure uses a circular buffer with a limited capacity to store timestamps. This helps limit memory usage in long-running applications. You can change the buffer capacity by passing a different value to the capacity keyword argument of the RxInferBenchmarkCallbacks constructor.
This information can be used to:
- Track performance statistics (min/max/average) of your inference procedure
- Identify performance variability across runs
- Monitor the time spent in different stages of inference
- Establish performance baselines for your models
- Detect performance regressions
The timestamps are collected using time_ns() for high precision timing measurements and are automatically formatted into human-readable durations when displayed.
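When the raw numbers are needed rather than the formatted table, the matrix layout described above for get_benchmark_stats (operation name, min, max, mean, median, standard deviation, with times in nanoseconds) can be post-processed directly. The sketch below is a minimal illustration under that assumption: mean_times_ms is a hypothetical helper, not part of RxInfer, and example_stats contains made-up numbers in the documented column order so that the block runs without the package.

```julia
# Hypothetical helper, not part of RxInfer: converts the mean column
# (column 4, in nanoseconds) to milliseconds for every operation row.
function mean_times_ms(stats::AbstractMatrix)
    return [String(row[1]) => row[4] / 1e6 for row in eachrow(stats)]
end

# Made-up numbers in the documented column layout:
# name, min, max, mean, median, std (all times in nanoseconds).
example_stats = [
    "Model creation" 2.0e5 4.0e5 3.0e5 3.0e5 1.0e5
    "Inference"      4.0e4 1.0e5 6.0e4 5.0e4 3.0e4
]

for (name, ms) in mean_times_ms(example_stats)
    println(name, ": ", ms, " ms on average")
end
```

With real measurements, the same helper would be applied to the matrix returned by RxInfer.get_benchmark_stats(benchmark_callbacks).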
The timing measurements include all overhead from the Julia runtime and may vary between runs. For more precise benchmarking of specific code sections, consider using the BenchmarkTools.jl package. When gathering performance statistics, consider running multiple iterations to get more reliable metrics.