Non-conjugate Inference
The RxInfer package excels in scenarios where the model uses conjugate priors for hidden states. Conjugate priors allow Bayesian inference to utilize pre-computed analytical update rules, significantly speeding up the inference process. For instance, the conjugate prior for the parameter of a Bernoulli distribution is the Beta distribution. The conjugate prior for the mean parameter of a Normal distribution is another Normal distribution, and the conjugate prior for the precision parameter of a Normal distribution is the Gamma distribution.
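To make the idea of an analytical update rule concrete, here is a minimal sketch of the Beta-Bernoulli case in plain Julia. The function name `beta_bernoulli_update` and the toy observations are illustrative assumptions, not part of RxInfer; the sketch only shows that the posterior stays in the Beta family and can be obtained by counting.

```julia
# Closed-form Beta-Bernoulli update (illustrative helper, not an RxInfer API).
# A Beta(α, β) prior combined with Bernoulli observations yields
# a Beta(α + successes, β + failures) posterior.
function beta_bernoulli_update(α, β, observations)
    successes = count(==(1), observations)
    failures = length(observations) - successes
    return (α = α + successes, β = β + failures)
end

# Toy data: four successes and two failures, starting from a flat Beta(1, 1) prior
posterior = beta_bernoulli_update(1.0, 1.0, [1, 0, 1, 1, 0, 1])

# Mean of a Beta(α, β) distribution is α / (α + β)
posterior_mean = posterior.α / (posterior.α + posterior.β)
```

No optimization or sampling is needed here, which is why conjugate models are fast.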
Non-conjugate Structures
However, models often contain non-conjugate structures, which prevent RxInfer from performing efficient inference. A prior is non-conjugate when combining it with the likelihood yields a posterior that does not belong to the same family as the prior. This complicates the inference process because it requires approximations or numerical methods instead of simple analytical updates.
Example Scenario
Consider the scenario where we assign the Beta distribution as a prior for the mean parameter of a Normal distribution. Let's explore what happens in this case with an example.
First, we generate some synthetic data:
using Distributions, ExponentialFamily, Plots, StableRNGs
# The model will infer the hidden parameters from data
hidden_mean = 0.2
hidden_precision = 0.8
hidden_distribution = NormalMeanPrecision(hidden_mean, hidden_precision)
number_of_datapoints = 1000
data = rand(StableRNG(42), hidden_distribution, number_of_datapoints)
histogram(data; normalize = :pdf)
Next, we specify the model. Suppose we believe the data follows a Normal distribution, and we are confident that the mean parameter is between 0 and 1. The Beta distribution is a logical choice for the prior of the mean parameter because it models a continuous variable in the range from 0 to 1. Similarly, we assign a Beta prior for the precision parameter, assuming it also lies between 0 and 1.
using RxInfer
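@model function non_conjugate_model(y)
    # Beta priors restrict both parameters to the interval (0, 1)
    m ~ Beta(1, 1)
    p ~ Beta(1, 1)
    # Each observation is Normal with shared mean `m` and precision `p`
    y .~ Normal(mean = m, precision = p)
end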
If we attempt inference with this model, RxInfer will throw an error because the necessary computational rules for such a model are not available in closed form. This is due to the non-conjugate nature of the priors used.
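To see why no closed-form rule exists, it helps to write out the exact posterior for the mean. The expression below is a sketch that treats the precision as fixed (a simplification made only for illustration) and writes the Beta prior with generic parameters α and β:

```latex
p(m \mid y, p) \;\propto\;
\underbrace{m^{\alpha - 1} (1 - m)^{\beta - 1}}_{\text{Beta prior}}
\;\prod_{i=1}^{N}
\underbrace{\exp\!\left(-\frac{p}{2}\,(y_i - m)^2\right)}_{\text{Normal likelihood}},
\qquad m \in (0, 1).
```

A Beta density depends on m only through log m and log(1 - m), whereas the likelihood contributes terms that are linear and quadratic in m. The product is therefore not a Beta density. Even with the flat Beta(1, 1) prior used in the model, the result is a Normal truncated to the interval (0, 1), which is still outside the Beta family.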
Addressing Non-conjugacy with ExponentialFamilyProjection
To overcome this limitation, RxInfer integrates with the ExponentialFamilyProjection package. This package re-projects non-conjugate relationships back into a member of the exponential family at the cost of some accuracy.
RxInfer supports non-conjugate inference for completeness, but be aware that inference execution times may increase significantly. This is because non-conjugate models require more complex computations, often involving sampling-based approximations.
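The idea of projecting an arbitrary density onto a chosen family can be illustrated with a much simpler stand-in technique: moment matching on a grid. The sketch below is not what ExponentialFamilyProjection does internally, and the helper `project_to_beta_by_moments` is a hypothetical name introduced only for this illustration. It takes an unnormalized log-density on (0, 1) and returns the Beta parameters with the same mean and variance.

```julia
# Illustrative stand-in for a projection step (NOT the ExponentialFamilyProjection algorithm):
# approximate an unnormalized density on (0, 1) by the Beta distribution
# that has the same mean and variance, computed on a regular grid.
function project_to_beta_by_moments(logdensity; gridsize = 2001)
    # Interior grid points only, so that log(0) is never evaluated
    xs = range(0.0, 1.0; length = gridsize)[2:end-1]
    logw = logdensity.(xs)
    # Subtract the maximum before exponentiating for numerical stability
    w = exp.(logw .- maximum(logw))
    w ./= sum(w)
    μ = sum(w .* xs)
    v = sum(w .* (xs .- μ) .^ 2)
    # Method-of-moments estimates for Beta(α, β)
    ν = μ * (1 - μ) / v - 1
    return (α = μ * ν, β = (1 - μ) * ν)
end

# Sanity check: the kernel of Beta(3, 5) should be recovered with α close to 3 and β close to 5
fit = project_to_beta_by_moments(x -> 2 * log(x) + 4 * log(1 - x))

# A target outside the Beta family: a Normal-shaped bump restricted to (0, 1)
approx = project_to_beta_by_moments(x -> -0.5 * 800 * (x - 0.2)^2)
```

In the second call the result is only an approximation of the target, which is the accuracy cost mentioned above.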
Specifying Constraints
The projection constraint must be specified using the @constraints macro. For example:
using ExponentialFamilyProjection
@constraints function non_conjugate_model_constraints()
    # project variational posterior over `m` to `Beta`
    q(m) :: ProjectedTo(Beta)
    # project variational posterior over `p` to `Beta`
    q(p) :: ProjectedTo(Beta)
    # `m` and `p` are jointly independent
    q(m, p) = q(m)q(p)
end
non_conjugate_model_constraints (generic function with 1 method)
These constraints specify that the posterior distribution for the hidden variable m must be re-projected to a Beta distribution to cover the region from 0 to 1. The same applies to the variable p.
Note that the distribution specified in the @constraints macro does not need to match the distribution specified as a prior. For example, we could use a Gamma distribution as a prior and a Beta distribution as a posterior. The only requirement is that the support of the posterior distribution must be the same as or smaller than that of the prior.
We also assume that m and p are jointly independent with the q(m, p) = q(m)q(p) specification. Dropping the assumption of joint independence would require initializing messages for m and p without guarantees of convergence. Read more about factorization constraints in the Constraints Specification guide.
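For intuition, the factorization q(m, p) = q(m)q(p) leads to the usual mean-field coordinate updates. The expressions below are a sketch derived from the model as written above, with N observations and expectations taken under the current variational factors. They are not a description of RxInfer's internal message schedule.

```latex
\log q^{*}(m) = \log \mathrm{Beta}(m \mid 1, 1)
  - \frac{\mathbb{E}_{q(p)}[p]}{2} \sum_{i=1}^{N} (y_i - m)^2 + \text{const},
\\[6pt]
\log q^{*}(p) = \log \mathrm{Beta}(p \mid 1, 1)
  + \frac{N}{2} \log p
  - \frac{p}{2} \sum_{i=1}^{N} \mathbb{E}_{q(m)}\!\left[(y_i - m)^2\right] + \text{const}.
```

Neither right-hand side is the log-density of a Beta distribution, so each update has to be projected back onto the Beta family, as the q(m) and q(p) constraints request. Each update also depends on the other factor, which is why the factors need initial values.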
The ProjectedTo structure is defined in the ExponentialFamilyProjection package. To fully explore its capabilities and hyper-parameters, we invite you to read the detailed documentation.
We also need to initialize the inference procedure due to the factorization constraints. Read more about initialization in the corresponding section.
initialization = @initialization begin
    # initial variational posteriors for the factorized variables
    q(m) = Beta(1, 1)
    q(p) = Beta(1, 1)
end
Initial state:
q(m) = Beta{Float64}(α=1.0, β=1.0)
q(p) = Beta{Float64}(α=1.0, β=1.0)
Running the Inference
With everything set up, we can run the inference procedure:
result = infer(
    model = non_conjugate_model(),
    data = (y = data,),
    constraints = non_conjugate_model_constraints(),
    initialization = initialization,
    iterations = 25,
    free_energy = true
)
Inference results:
Posteriors | available for (m, p)
Free Energy: | Real[1563.75, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15 … 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15, 1559.15]
Analyzing the Results
Let's analyze the results using the StatsPlots package to visualize the resulting posteriors over individual VMP iterations:
using StatsPlots
@gif for (i, q) in enumerate(zip(result.posteriors[:m], result.posteriors[:p]))
    q_m = q[1]
    q_p = q[2]
    p1 = plot(q_m, label = "Inferred `m`", fill = 0, fillalpha = 0.2)
    p1 = vline!(p1, [hidden_mean], label = "Hidden `m`")
    p2 = plot(q_p, label = "Inferred `p`", fill = 0, fillalpha = 0.2)
    p2 = vline!(p2, [hidden_precision], label = "Hidden `p`")
    plot(p1, p2; title = "Iteration $i")
end fps = 15

As we can see, the estimated posteriors are quite close to the actual hidden parameters used to generate our dataset. We can also inspect the Bethe Free Energy values to check that the result has converged:
plot(result.free_energy, label = "Bethe Free Energy (per iteration)")
The Bethe Free Energy levels off after the first few iterations, which indicates that the inference process has stabilized and the variational posteriors are no longer changing appreciably.
Note that the projection method uses stochastic gradient computations, which may cause small fluctuations in the estimates and in the Bethe Free Energy values between iterations.
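Because of these fluctuations, an exact equality test between consecutive free energy values can be too strict. One possible approach, sketched below in plain Julia, is to check that the last few values stay within a small relative tolerance. The helper `has_converged` and its `window` and `rtol` parameters are assumptions made for this illustration, not part of RxInfer. The sample values are the first ten entries of the free energy output shown above.

```julia
# Illustrative convergence check (hypothetical helper, not an RxInfer API):
# the trace is considered converged when the last `window + 1` values
# differ by at most `rtol` relative to the final value.
function has_converged(values; window = 5, rtol = 1e-4)
    length(values) > window || return false
    tail = values[end-window:end]
    spread = maximum(tail) - minimum(tail)
    return spread <= rtol * abs(last(tail))
end

# First ten free energy values from the inference output above
free_energy_trace = [1563.75; fill(1559.15, 9)]

converged = has_converged(free_energy_trace)
too_short = has_converged(free_energy_trace[1:2])
```

In a live session the same function could be applied to `result.free_energy`. The tolerance may need to be loosened when the stochastic projection produces noisier traces.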