Benchmark callbacks

RxInfer provides a built-in callback structure, RxInferBenchmarkCallbacks, for collecting timing information during inference. It aggregates timestamps across multiple inference runs so that you can track performance statistics (minimum, maximum, mean, median, and standard deviation) for model creation and inference. For general information about the callbacks system, see Callbacks.

Basic usage

using RxInfer

@model function iid_normal(y)
    μ  ~ Normal(mean = 0.0, variance = 100.0)
    γ  ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = γ)
end

init = @initialization begin
    q(μ) = vague(NormalMeanVariance)
end

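# Warm up to avoid measuring compilation time
infer(
    model = iid_normal(),
    data = (y = randn(100),),
    constraints = MeanField(),
    iterations = 5,
    initialization = init,
)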

# Create a benchmark callbacks instance to track performance
benchmark_callbacks = RxInferBenchmarkCallbacks()

# Run inference multiple times to gather statistics
for i in 1:5
    infer(
        model = iid_normal(),
        data = (y = randn(100),),
        constraints = MeanField(),
        iterations = 5,
        initialization = init,
        callbacks = benchmark_callbacks,
    )
end

The benchmark callbacks instance accumulates timestamps across multiple calls to infer, making it easy to collect performance statistics over many runs.

Displaying results

Install PrettyTables.jl to display the collected statistics in a nicely formatted table:

using PrettyTables

PrettyTables.pretty_table(benchmark_callbacks)
RxInfer inference benchmark statistics: 5 evaluations
╭────────────────┬────────────┬────────────┬────────────┬────────────┬────────────╮
│      Operation │        Min │        Max │       Mean │     Median │        Std │
├────────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ Model creation │   3.457 ms │   3.726 ms │   3.555 ms │   3.507 ms │ 111.983 μs │
│      Inference │   1.730 ms │   1.840 ms │   1.760 ms │   1.745 ms │  44.791 μs │
│      Iteration │ 327.613 μs │ 410.706 μs │ 343.324 μs │ 337.488 μs │  17.549 μs │
╰────────────────┴────────────┴────────────┴────────────┴────────────┴────────────╯

Accessing from model metadata

After model creation, the benchmark callbacks instance is automatically saved into the model's metadata under the :benchmark key. This makes it accessible from the inference result without needing to hold onto the callbacks object separately:

result = infer(
    model = iid_normal(),
    data = (y = randn(100),),
    constraints = MeanField(),
    iterations = 5,
    initialization = init,
    callbacks = RxInferBenchmarkCallbacks(),
)


benchmark = result.model.metadata[:benchmark]
println(benchmark)
RxInferBenchmarkCallbacks (1evaluations, use `pretty_table` from `PrettyTables.jl` to display the statistics in a tabular format)

Tracked events

The RxInferBenchmarkCallbacks structure collects timestamps at the following stages:

Event                         | Batch inference | Streamline inference
Model creation (before/after) | yes             | yes
Inference (before/after)      | yes             | no
Each iteration (before/after) | yes             | no
Autostart (before/after)      | no              | yes
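
To make the before/after pairing concrete, here is a minimal sketch in plain Julia of how matching timestamps become durations and how those durations can be summarized. It does not use RxInfer: `timed_runs` is a hypothetical helper, and taking timestamps with `time_ns()` is an assumption made for this illustration, not a statement about RxInfer's internals.

```julia
using Statistics

# Hypothetical helper: record a timestamp before and after each call of `f`
function timed_runs(f, n)
    before = UInt64[]
    after  = UInt64[]
    for _ in 1:n
        push!(before, time_ns())
        f()
        push!(after, time_ns())
    end
    return before, after
end

before, after = timed_runs(() -> sum(rand(1_000)), 5)

# Each duration is the difference of a matching before/after pair, in nanoseconds
durations = Float64.(after .- before)

stats_summary = (
    min    = minimum(durations),
    max    = maximum(durations),
    mean   = mean(durations),
    median = median(durations),
    std    = std(durations),
)
println(stats_summary)
```

An event that is only recorded for one inference mode simply contributes no before/after pairs in the other mode.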

Buffer capacity

By default, the structure uses circular buffers with a capacity of RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY entries. This limits memory usage in long-running applications. You can change the capacity:

# Store up to 10000 benchmark entries
large_buffer_callbacks = RxInferBenchmarkCallbacks(capacity = 10_000)
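
The effect of a bounded capacity can be illustrated without RxInfer. The sketch below mimics a circular buffer with a plain `Vector`; `push_bounded!` is a hypothetical helper written only for this illustration, and the real structure may differ in detail.

```julia
# Hypothetical helper: append `x`, dropping the oldest entry once `capacity` is exceeded
function push_bounded!(buffer::Vector, x, capacity::Integer)
    push!(buffer, x)
    while length(buffer) > capacity
        popfirst!(buffer)
    end
    return buffer
end

timestamps = Int[]
for t in 1:5
    push_bounded!(timestamps, t, 3)
end

println(timestamps) # only the three most recent entries remain: [3, 4, 5]
```

Under this model, statistics computed from a buffer with capacity N describe at most the N most recent runs, so a larger capacity trades memory for a longer history.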

Programmatic access to statistics

Use RxInfer.get_benchmark_stats to retrieve the raw statistics matrix:

# Use the previously populated benchmark_callbacks
stats = RxInfer.get_benchmark_stats(benchmark_callbacks)


for row in eachrow(stats)
    println(row[1], ": min=", round(row[2] / 1e6, digits=2), "ms, mean=", round(row[4] / 1e6, digits=2), "ms")
end
Model creation: min=3.46ms, mean=3.55ms
Inference: min=1.73ms, mean=1.76ms
Iteration: min=0.33ms, mean=0.34ms

The matrix contains the following columns:

  1. Operation name (String)
  2. Minimum time (Float64, nanoseconds)
  3. Maximum time (Float64, nanoseconds)
  4. Mean time (Float64, nanoseconds)
  5. Median time (Float64, nanoseconds)
  6. Standard deviation (Float64, nanoseconds)
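
Given that column layout, a row can be turned into a named record and the nanosecond values rendered with a readable unit. The following is a sketch under assumptions: `format_ns` is a hypothetical helper, and the `stats` matrix here is a hand-written stand-in with illustrative values rather than the result of an actual call to `RxInfer.get_benchmark_stats`.

```julia
# Hypothetical helper: render a duration given in nanoseconds with a readable unit
function format_ns(t::Real)
    t >= 1e9 && return string(round(t / 1e9, digits = 3), " s")
    t >= 1e6 && return string(round(t / 1e6, digits = 3), " ms")
    t >= 1e3 && return string(round(t / 1e3, digits = 3), " μs")
    return string(round(t, digits = 3), " ns")
end

# Illustrative stand-in for the statistics matrix (not real measurements)
stats = Any[
    "Model creation" 3.457e6 3.726e6 3.555e6 3.507e6 111983.0
    "Inference"      1.730e6 1.840e6 1.760e6 1.745e6 44791.0
]

# Turn each row into a NamedTuple keyed by the meaning of each column
rows = [
    (operation = r[1], min = r[2], max = r[3], mean = r[4], median = r[5], std = r[6])
    for r in eachrow(stats)
]

for r in rows
    println(r.operation, ": mean = ", format_ns(r.mean), ", std = ", format_ns(r.std))
end
```

Naming the columns once avoids scattering numeric indices such as `row[4]` through downstream code.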

API Reference

RxInfer.RxInferBenchmarkCallbacks (Type)
RxInferBenchmarkCallbacks(; capacity = RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)

A callback structure for collecting timing information during the inference procedure. This structure collects timestamps for various stages of the inference process and aggregates them across multiple runs, allowing you to track performance statistics (min/max/average/etc.) of your model's creation and inference procedure. The structure supports pretty printing by default, displaying timing statistics in a human-readable format.

The structure uses circular buffers with a default capacity of 1000 entries to store timestamps, which helps to limit memory usage in long-running applications. Use RxInferBenchmarkCallbacks(; capacity = N) to change the buffer capacity. See also RxInfer.get_benchmark_stats(callbacks).

After model creation, the benchmark callbacks instance is automatically saved into the model's metadata under the :benchmark key (i.e., model.metadata[:benchmark]), making it accessible from the inference result via result.model.metadata[:benchmark].

Fields

  • before_model_creation_ts: CircularBuffer of timestamps before model creation
  • after_model_creation_ts: CircularBuffer of timestamps after model creation
  • before_inference_ts: CircularBuffer of timestamps before inference starts
  • after_inference_ts: CircularBuffer of timestamps after inference ends
  • before_iteration_ts: CircularBuffer of vectors of timestamps before each iteration
  • after_iteration_ts: CircularBuffer of vectors of timestamps after each iteration
  • before_autostart_ts: CircularBuffer of timestamps before autostart
  • after_autostart_ts: CircularBuffer of timestamps after autostart

Example

# Create a callbacks instance to track performance
callbacks = RxInferBenchmarkCallbacks()

# Run inference multiple times to gather statistics
# (collect the results so they remain accessible after the loop)
results = map(1:10) do _
    infer(
        model = my_model(),
        data = my_data,
        callbacks = callbacks
    )
end

# Access the benchmark callbacks from the inference result
last(results).model.metadata[:benchmark] === callbacks # true

# Display the timing statistics (you need to install `PrettyTables.jl` to use `pretty_table` function)
using PrettyTables

PrettyTables.pretty_table(callbacks)
RxInfer.get_benchmark_stats (Function)
get_benchmark_stats(callbacks::RxInferBenchmarkCallbacks)

Returns a matrix containing benchmark statistics for different operations in the inference process. The matrix contains the following columns:

  1. Operation name (String)
  2. Minimum time (Float64)
  3. Maximum time (Float64)
  4. Mean time (Float64)
  5. Median time (Float64)
  6. Standard deviation (Float64)

Each row represents a different operation (model creation, inference, iteration, autostart). Times are in nanoseconds.

Note

The timing measurements include all overhead from the Julia runtime and may vary between runs. For more precise benchmarking of specific code sections, consider using the BenchmarkTools.jl package.