Inference execution
The RxInfer inference API supports different types of message-passing algorithms (including hybrid algorithms combining several different types). While RxInfer implements several algorithms to cater to different computational needs and scenarios, the core message-passing algorithms that form the foundation of our inference capabilities are:
- Belief propagation
- Variational message passing
Belief propagation computes exact marginals for the random variables of interest when the factor graph is tree-structured, whereas variational message passing (VMP) is an approximation method that can be applied to a larger range of models.
The inference engine itself isn't aware of different algorithm types and simply performs message passing between nodes. However, during the model specification stage, the user may specify different factorisation constraints around factor nodes with the help of the @constraints macro. Different factorisation constraints lead to different message passing update rules. See the corresponding section for more documentation on constraints specification.
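As a minimal sketch of how a constraint changes the update rules, the example below imposes the mean-field factorisation q(μ, τ) = q(μ)q(τ) on a Gaussian model with unknown mean and precision, so the engine uses VMP updates around the likelihood nodes. The model name iid_normal, its priors and the data values are illustrative assumptions, not part of the API. Because the two factorised marginals depend on each other, q(τ) is given a starting value with @initialization and several iterations are run.

```julia
using RxInfer

# Hypothetical example model: i.i.d. Gaussian data with unknown mean μ and precision τ
@model function iid_normal(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Mean-field factorisation: q(μ, τ) = q(μ)q(τ) leads to VMP update rules
constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end

# Starting marginal for τ breaks the circular dependency between q(μ) and q(τ)
init = @initialization begin
    q(τ) = GammaShapeRate(1.0, 1.0)
end

# Illustrative observations (sample mean is exactly 1.0)
y = [0.9, 1.1, 1.3, 0.7, 1.0, 1.2, 0.8, 1.0]

result = infer(
    model = iid_normal(),
    data = (y = y,),
    constraints = constraints,
    initialization = init,
    iterations = 10
)

# With `iterations` set, one posterior is stored per iteration; take the last one
println(mean(result.posteriors[:μ][end]))
```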
Automatic inference specification
RxInfer exports the infer function to quickly run and test your model with both static and asynchronous (real-time) datasets. See more information about the infer function on the separate documentation section:
RxInfer.infer — Function
infer(
model;
data = nothing,
datastream = nothing,
autoupdates = nothing,
initialization = nothing,
constraints = nothing,
meta = nothing,
options = nothing,
returnvars = nothing,
predictvars = nothing,
historyvars = nothing,
keephistory = nothing,
iterations = nothing,
free_energy = false,
free_energy_diagnostics = DefaultObjectiveDiagnosticChecks,
showprogress = false,
callbacks = nothing,
addons = nothing,
postprocess = DefaultPostprocess(),
warn = true,
events = nothing,
uselock = false,
autostart = true,
catch_exception = false,
session = RxInfer.default_session()
)
This function provides a generic way to perform probabilistic inference for batch/static and streamline/online scenarios. It returns either an InferenceResult (batch setting) or an RxInferenceEngine (streamline setting), depending on the parameters used.
Before using this function, you may want to review common issues and solutions in the Sharp bits of RxInfer section of the documentation.
Arguments
Check the official documentation for more information about some of the arguments.
- model: specifies a model generator, required
- data: NamedTuple or Dict with data, required (or datastream or predictvars)
- datastream: a stream of NamedTuple with data, required (or data)
- autoupdates = nothing: auto-updates specification, required for streamline inference, see @autoupdates
- initialization = nothing: initialization specification object, optional, see @initialization
- constraints = nothing: constraints specification object, or an alias such as MeanField, optional, see @constraints
- meta = nothing: meta specification object, optional, may be required for some models, see @meta
- options = nothing: model creation options, optional, see ReactiveMPInferenceOptions
- returnvars = nothing: return structure info, optional, defaults to returning everything at each iteration
- predictvars = nothing: return structure info, optional (exclusive for batch inference)
- historyvars = nothing: history structure info, optional, defaults to no history (exclusive for streamline inference)
- keephistory = nothing: history buffer size, defaults to an empty buffer (exclusive for streamline inference)
- iterations = nothing: number of iterations, optional, defaults to nothing; the inference engine does not distinguish between variational message passing, loopy belief propagation or expectation propagation iterations
- free_energy = false: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. Float64, for better efficiency, but this disables automatic differentiation packages, such as ForwardDiff.jl
- free_energy_diagnostics = DefaultObjectiveDiagnosticChecks: free energy diagnostic checks, optional, by default checks for possible NaNs and Infs. nothing disables all checks.
- showprogress = false: show progress module, optional, defaults to false (exclusive for batch inference)
- catch_exception = false: specifies whether exceptions during the inference procedure should be caught, optional, defaults to false (exclusive for batch inference)
- callbacks = nothing: inference cycle callbacks, optional
- addons = nothing: inject and send extra computation information along messages
- postprocess = DefaultPostprocess(): inference results postprocessing step, optional
- events = nothing: inference cycle events, optional (exclusive for streamline inference)
- uselock = false: specifies whether to use a lock structure for the inference; if set to true, uses Base.Threads.SpinLock. Accepts a custom AbstractLock. (exclusive for streamline inference)
- autostart = true: specifies whether to call RxInfer.start on the created engine automatically (exclusive for streamline inference)
- warn = true: enables/disables warnings
- session = RxInfer.default_session(): current logging session for the RxInfer invokes, see Session for more details; pass nothing to disable logging
Error hints
By default, RxInfer provides helpful error hints with documentation links, solutions, and troubleshooting guidance.
Use RxInfer.disable_inference_error_hint!() to disable error hints or RxInfer.enable_inference_error_hint!() to enable them. Note that changes to error hint settings require a Julia session restart to take effect.
See also: RxInfer.disable_inference_error_hint!, RxInfer.enable_inference_error_hint!
Note on NamedTuples
When passing a NamedTuple as a value for some argument, make sure you use a trailing comma for NamedTuples with a single entry. The reason is that Julia treats the expressions returnvars = (x = KeepLast()) and returnvars = (x = KeepLast(), ) differently. The first expression creates (or overwrites!) a local/global variable named x with the contents KeepLast(). The second expression (note the trailing comma) creates a NamedTuple with x as a key and KeepLast() as the value assigned to this key.
(x = KeepLast()) # defines a variable `x` with the value `KeepLast()`
KeepLast()
(x = KeepLast(), ) # defines a NamedTuple with `x` as one of the keys and value `KeepLast()`
(x = KeepLast(),)
model
Also read the Model Specification section.
The model argument accepts a model specification as its input. The easiest way to create the model is to use the @model macro. For example:
@model function beta_bernoulli(y, a, b)
x ~ Beta(a, b)
y .~ Bernoulli(x)
end
result = infer(
model = beta_bernoulli(a = 1, b = 1),
data = (y = [ true, false, false ], )
)
result.posteriors[:x]
Beta{Float64}(α=2.0, β=3.0)
The model keyword argument does not accept a ProbabilisticModel instance as a value, as it needs to inject constraints and meta during the inference procedure.
data
Either the data or the datastream keyword argument is required. Specifying both data and datastream is not supported and will result in an error.
The behavior of the data keyword argument depends on the inference setting (batch or streamline).
The data keyword argument must be a NamedTuple (or Dict) where keys (of Symbol type) correspond to some arguments defined in the model specification. For example, if a model defines y in its argument list
@model function beta_bernoulli(y, a, b)
x ~ Beta(a, b)
y .~ Bernoulli(x)
end
and you want to condition on this argument, then the data field must have a :y key (of Symbol type) which holds the data. The values in the data must have exactly the same shape as the corresponding variable container. E.g. in the example above, y is used in a broadcasting operation, so it must be a collection of values. The a and b arguments, however, can be single numbers:
result = infer(
model = beta_bernoulli(),
data = (y = [ true, false, false ], a = 1, b = 1)
)
result.posteriors[:x]
Beta{Float64}(α=2.0, β=3.0)
datastream
Also read the Streamlined Inference section.
The datastream keyword argument must be an observable that supports the subscribe! and unsubscribe! functions (e.g., streams from the Rocket.jl package). The elements of the observable must be of type NamedTuple, where keys (of Symbol type) correspond to input arguments defined in the model specification, except for those listed in the @autoupdates specification. For example, if a model defines y as its argument (which is not part of the @autoupdates specification), the named tuples from the observable must have a :y key (of Symbol type). The values in the named tuple must have exactly the same shape as the corresponding variable container.
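A minimal streaming sketch, modelled on the beta-Bernoulli example above: each element of the stream is a NamedTuple with a :y key, and @autoupdates feeds the latest posterior parameters of x back into the a and b arguments before the next observation. The model name beta_bernoulli_online and the observation values are illustrative assumptions. The stream is built with Rocket.jl's from and map operators, which RxInfer re-exports.

```julia
using RxInfer

# Hypothetical online variant of the beta-Bernoulli model: one observation per step
@model function beta_bernoulli_online(y, a, b)
    x ~ Beta(a, b)
    y ~ Bernoulli(x)
end

# After each step, the current posterior parameters become the next prior
autoupdates = @autoupdates begin
    a, b = params(q(x))
end

# The first step needs an initial q(x) to read `a` and `b` from
init = @initialization begin
    q(x) = Beta(1.0, 1.0)
end

observations = [true, false, false, true]

# Wrap each observation in a NamedTuple with a `:y` key
datastream = from(observations) |> map(NamedTuple{(:y,), Tuple{Bool}}, (d) -> (y = d,))

engine = infer(
    model = beta_bernoulli_online(),
    datastream = datastream,
    autoupdates = autoupdates,
    initialization = init,
    returnvars = (:x,),
    historyvars = (x = KeepLast(),),
    keephistory = 10,
    autostart = true
)

# One stored posterior per processed observation
println(engine.history[:x])
```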
initialization
Also read the Initialization section.
For specific types of inference algorithms, such as variational message passing, it might be required to initialize (some of) the marginals before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not be executed due to the lack of information, and messages and/or marginals will not be updated. To specify these initial marginals and messages, use the initialization argument in combination with the @initialization macro, for example:
init = @initialization begin
# initialize the marginal distribution of x as a vague Normal distribution
# if x is a vector, then it simply uses the same value for all elements
# However, it is also possible to provide a vector of distributions to set each element individually
q(x) = vague(NormalMeanPrecision)
end
Initial state:
  q(x) = NormalMeanPrecision{Float64}(μ=0.0, w=1.0e-12)
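The initialization object is then passed to infer. The sketch below assumes an illustrative iid_normal model (unknown mean μ and precision τ) and uses the MeanField() constraints alias; since both marginals are factorised, each is given an initial value so that the first VMP iteration can compute its messages. The model name, priors, initial values and data are assumptions for illustration only.

```julia
using RxInfer

# Hypothetical example model: i.i.d. Gaussian data with unknown mean μ and precision τ
@model function iid_normal(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Initial marginals for both factorised variables
init = @initialization begin
    q(μ) = NormalMeanVariance(0.0, 100.0)
    q(τ) = GammaShapeRate(1.0, 1.0)
end

# Illustrative observations
y = [0.9, 1.1, 1.3, 0.7, 1.0, 1.2, 0.8, 1.0]

result = infer(
    model = iid_normal(),
    data = (y = y,),
    constraints = MeanField(),
    initialization = init,
    iterations = 10
)

println(mean(result.posteriors[:μ][end]))
```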
returnvars
returnvars specifies latent variables of interest and their posterior updates. Its behavior depends on the inference type: streamline or batch.
Batch inference:
- Accepts a NamedTuple or Dict of return variable specifications.
- Two specifications are available: KeepLast (saves the last update) and KeepEach (saves all updates).
- When iterations is set, returns every update for each iteration (equivalent to KeepEach()); if nothing, saves the last update (equivalent to KeepLast()).
- Use iterations = 1 to force KeepEach() for a single iteration or set returnvars = KeepEach() manually.
result = infer(
...,
returnvars = (
x = KeepLast(),
τ = KeepEach()
)
)
Shortcut for setting the same option for all variables:
result = infer(
...,
returnvars = KeepLast() # or KeepEach()
)
Streamline inference:
- For each symbol in returnvars, infer creates an observable stream of posterior updates.
- Agents can subscribe to these updates using the Rocket.jl package.
engine = infer(
...,
autoupdates = my_autoupdates,
returnvars = (:x, :τ),
autostart = false
)
RxInfer.KeepLast — Type
Instructs the inference engine to keep only the last marginal update and disregard intermediate updates.
RxInfer.KeepEach — Type
Instructs the inference engine to keep each marginal update for all intermediate iterations.
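A small batch sketch contrasting the two specifications on the beta_bernoulli model from the model section (redefined here so the snippet runs on its own). With five iterations, KeepLast should return a single distribution, while KeepEach should return one posterior per iteration. The data values are illustrative.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x ~ Beta(a, b)
    y .~ Bernoulli(x)
end

data = (y = [true, false, false],)

# Keep only the final posterior, even though 5 iterations are run
last_only = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data = data,
    iterations = 5,
    returnvars = (x = KeepLast(),)
)

# Keep the posterior computed at every iteration
every_step = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data = data,
    iterations = 5,
    returnvars = (x = KeepEach(),)
)

println(last_only.posteriors[:x])          # a single distribution
println(length(every_step.posteriors[:x])) # one entry per iteration
```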
predictvars
predictvars specifies the variables which should be predicted. Similar to returnvars, predictvars accepts a NamedTuple or Dict. There are two specifications:
- KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
- KeepEach: saves all updates for a variable for all iterations
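A self-contained sketch of predictvars, assuming the installed RxInfer version accepts missing entries in data to mark values that should be predicted. The beta_bernoulli model is redefined so the snippet runs on its own; the data are illustrative. After the two observed values, the posterior over x should be Beta(2, 2), so the prediction for the missing entry is driven by that posterior.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x ~ Beta(a, b)
    y .~ Bernoulli(x)
end

result = infer(
    model = beta_bernoulli(a = 1, b = 1),
    # The `missing` entry is the value we ask the engine to predict
    data = (y = [true, false, missing],),
    predictvars = (y = KeepLast(),)
)

println(result.predictions[:y])
```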
result = infer(
...,
predictvars = (
o = KeepLast(),
τ = KeepEach()
)
)
historyvars
Also read the Keeping the history of posteriors section.
historyvars specifies the variables of interest and the amount of information to keep in history about the posterior updates when performing streamline inference. The specification is similar to returnvars in the batch setting. historyvars requires keephistory to be greater than zero.
historyvars accepts a NamedTuple, a Dict, or a single return variable specification. There are two specifications:
- KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
- KeepEach: saves all updates for a variable for all iterations
result = infer(
...,
autoupdates = my_autoupdates,
historyvars = (
x = KeepLast(),
τ = KeepEach()
),
keephistory = 10
)
It is also possible to set either historyvars = KeepLast() or historyvars = KeepEach(), which acts as an alias and sets the given option for all random variables in the model.
result = infer(
...,
autoupdates = my_autoupdates,
historyvars = KeepLast(),
keephistory = 10
)
keephistory
Specifies the buffer size for the updates history both for the historyvars and the free_energy buffers in streamline inference.
iterations
Specifies the number of variational (or loopy belief propagation) iterations. By default it is set to nothing, which is equivalent to doing 1 iteration. However, if set explicitly to 1, the default setting for returnvars changes from KeepLast to KeepEach.
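The sketch below contrasts the default iterations = nothing with an explicit iterations = 1 on the beta_bernoulli model (redefined so the snippet is self-contained). Both run a single pass, but following the returnvars default described above, only the second should return a vector of posteriors. The data values are illustrative.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x ~ Beta(a, b)
    y .~ Bernoulli(x)
end

data = (y = [true, false, false],)

# iterations = nothing (default): returnvars defaults to KeepLast
r_default = infer(model = beta_bernoulli(a = 1, b = 1), data = data)

# iterations = 1: same single pass, but returnvars defaults to KeepEach
r_one = infer(model = beta_bernoulli(a = 1, b = 1), data = data, iterations = 1)

println(typeof(r_default.posteriors[:x]))
println(typeof(r_one.posteriors[:x]))
```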
free_energy
Batch inference:
Specifies if the infer function should return Bethe Free Energy (BFE) values.
- Optionally accepts a floating-point type (e.g., Float64) for improved BFE computation performance, but restricts the use of automatic differentiation packages.
Streamline inference:
Specifies if the infer function should create an observable stream of Bethe Free Energy (BFE) values, computed at each VMP iteration.
- When free_energy = true and keephistory > 0, additional fields are exposed in the engine for accessing the history of BFE updates:
  - engine.free_energy_history: averaged BFE history over VMP iterations.
  - engine.free_energy_final_only_history: BFE history of values computed in the last VMP iterations for each observation.
  - engine.free_energy_raw_history: raw BFE history.
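In the batch setting, a minimal sketch that requests the Bethe free energy for the beta_bernoulli model over three iterations, giving one value per iteration. Since this model is tree-structured and belief propagation is exact here, we would expect each value to be close to −log p(y) = log 12 ≈ 2.48 for the three observations used; the model and data are illustrative.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x ~ Beta(a, b)
    y .~ Bernoulli(x)
end

result = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data = (y = [true, false, false],),
    iterations = 3,
    free_energy = true
)

# One Bethe free energy value per iteration
println(result.free_energy)
```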
free_energy_diagnostics
This setting specifies either a single diagnostic check or a tuple of diagnostic checks for the Bethe Free Energy values stream. By default, it checks for NaNs and Infs. See also RxInfer.ObjectiveDiagnosticCheckNaNs and RxInfer.ObjectiveDiagnosticCheckInfs. Pass nothing to disable all checks.
options
RxInfer.ReactiveMPInferenceOptions — Type
ReactiveMPInferenceOptions(; kwargs...)
Creates a model inference options object. The list of available options is given below.
Options
- limit_stack_depth: limits the stack depth for computing messages, which helps with StackOverflowError for some huge models but reduces the performance of the inference backend. Accepts an integer that specifies the maximum recursion depth. Lower values are better for avoiding stack overflow errors, but worse for performance.
- warn: (optional) flag to suppress warnings. Warnings are not displayed if set to false. Defaults to true.
- force_marginal_computation: (optional) flag to force computation of marginals even when not explicitly requested. Defaults to false.
Advanced options
- scheduler: changes the scheduler of reactive streams, see Rocket.jl for more info, defaults to AsapScheduler.
- rulefallback: specifies a global message update rule fallback for cases when a specific message update rule is not available. Consult the ReactiveMP documentation for the list of available fallbacks.
See also: infer
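The options argument can be given as a NamedTuple of the fields listed above, which we assume infer converts into ReactiveMPInferenceOptions. A minimal sketch on the beta_bernoulli model with 1000 synthetic observations (500 true, 500 false) that caps the recursion depth with limit_stack_depth; the value 100 is an arbitrary illustrative choice. The conjugate posterior should be Beta(501, 501).

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x ~ Beta(a, b)
    y .~ Bernoulli(x)
end

# Deterministic synthetic data: odd indices are `true`, even indices are `false`
y = [isodd(i) for i in 1:1000]

result = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data = (y = y,),
    options = (limit_stack_depth = 100,)
)

println(result.posteriors[:x])
```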
catch_exception
The catch_exception keyword argument specifies whether exceptions during the batch inference procedure should be caught in the error field of the result. By default, if an exception occurs during the inference procedure, the result is lost. Set catch_exception = true to obtain a partial result if an exception occurs. Use the RxInfer.issuccess and RxInfer.iserror functions to check whether the inference completed successfully or failed. If an error occurs, the error field stores a tuple, where the first element is the exception itself and the second element is the caught backtrace. Use the stacktrace function with the backtrace as an argument to recover the stacktrace of the error. Use the Base.showerror function to display the error.
RxInfer.issuccess — Function
Checks if the InferenceResult object does not contain an error.
RxInfer.iserror — Function
Checks if the InferenceResult object contains an error.
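A sketch of the error-checking pattern described above. The inference here is expected to succeed, so the else branch runs; the if branch shows how the (exception, backtrace) tuple stored in result.error could be unpacked and displayed with Base.showerror. The beta_bernoulli model and data are illustrative.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x ~ Beta(a, b)
    y .~ Bernoulli(x)
end

result = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data = (y = [true, false, false],),
    catch_exception = true
)

if RxInfer.iserror(result)
    # `result.error` holds a tuple (exception, backtrace)
    exception, backtrace = result.error
    Base.showerror(stderr, exception, backtrace)
    println(stderr)
else
    println("Inference succeeded: ", result.posteriors[:x])
end
```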
callbacks
The inference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject some extra logging or other procedures in the inference function, e.g.
result = infer(
...,
callbacks = (
on_marginal_update = (model, name, update) -> println("$(name) has been updated: $(update)"),
after_inference = (args...) -> println("Inference has been completed")
)
)
The callbacks keyword argument accepts a named tuple of 'name = callback' pairs. The list of all possible callbacks for the different inference settings (batch or streamline) and their arguments is given below:
- before_model_creation()
- after_model_creation(model::ProbabilisticModel)
Exclusive for batch inference
- on_marginal_update(model::ProbabilisticModel, name::Symbol, update)
- before_inference(model::ProbabilisticModel)
- before_iteration(model::ProbabilisticModel, iteration::Int)::Bool
- before_data_update(model::ProbabilisticModel, data)
- after_data_update(model::ProbabilisticModel, data)
- after_iteration(model::ProbabilisticModel, iteration::Int)::Bool
- after_inference(model::ProbabilisticModel)
before_iteration and after_iteration callbacks are allowed to return true/false value. true indicates that iterations must be halted and no further inference should be made.
Exclusive for streamline inference
- before_autostart(engine::RxInferenceEngine)
- after_autostart(engine::RxInferenceEngine)
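A sketch of the halting behaviour of after_iteration in batch inference: the hypothetical callback stop_after_three records the iteration number and returns true once three iterations have completed, so inference should stop before the requested ten iterations. The model, data and callback name are illustrative assumptions.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x ~ Beta(a, b)
    y .~ Bernoulli(x)
end

# Mutable counter updated by the callback
completed_iterations = Ref(0)

# Hypothetical callback: returning `true` requests that iterations be halted
function stop_after_three(model, iteration)
    completed_iterations[] = iteration
    return iteration >= 3
end

result = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data = (y = [true, false, false],),
    iterations = 10,
    callbacks = (after_iteration = stop_after_three,)
)

println(completed_iterations[])
```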
addons
The addons field extends the default message computation rules with some extra information, e.g. computing log-scaling factors of messages or saving debug-information. Accepts a single addon or a tuple of addons. Automatically changes the default value of the postprocess argument to NoopPostprocess.
postprocess
Also read the Inference results postprocessing section.
The postprocess keyword argument controls whether the inference results must be modified in some way before exiting the inference function. By default, the inference function uses the DefaultPostprocess strategy, which removes the Marginal wrapper type from the results. Change this setting to NoopPostprocess if you would like to keep the Marginal wrapper type, which might be useful in combination with the addons argument. If the addons argument has been used, the default strategy automatically changes to NoopPostprocess.
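A minimal sketch comparing the default postprocessing with NoopPostprocess on the beta_bernoulli model (redefined so the snippet runs on its own): with the default strategy the posterior should be a plain Beta distribution, while with RxInfer.NoopPostprocess() it should stay wrapped in the Marginal type. The data values are illustrative.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x ~ Beta(a, b)
    y .~ Bernoulli(x)
end

data = (y = [true, false, false],)

# DefaultPostprocess: Marginal wrapper removed
default_result = infer(model = beta_bernoulli(a = 1, b = 1), data = data)

# NoopPostprocess: results returned as computed, wrapper kept
raw_result = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data = data,
    postprocess = RxInfer.NoopPostprocess()
)

println(typeof(default_result.posteriors[:x]))
println(typeof(raw_result.posteriors[:x]))
```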
Error hints
By default, RxInfer provides helpful error hints when an error occurs during inference. This, for example, includes links to relevant documentation, common solutions and troubleshooting steps, information about where to get help, and suggestions for providing good bug reports.
Use RxInfer.disable_inference_error_hint! to disable error hints or RxInfer.enable_inference_error_hint! to enable them. Note that the change requires a Julia session restart to take effect.
RxInfer.disable_inference_error_hint! — Function
disable_inference_error_hint!()
Disable error hints that are shown when an error occurs during inference.
The change requires a Julia session restart to take effect. When disabled, only the raw error will be shown without additional context or suggestions.
See also: enable_inference_error_hint!, infer
RxInfer.enable_inference_error_hint! — Function
enable_inference_error_hint!()
Enable error hints that are shown when an error occurs during inference.
The change requires a Julia session restart to take effect. When enabled, errors during the inference call will include:
- Links to relevant documentation
- Common solutions and troubleshooting steps
- Information about where to get help
See also: disable_inference_error_hint!, infer
Where to go next?
Read more about the other keyword arguments in the Streamlined (online) inference section, check out the Static Inference section, or explore some more advanced examples.