Benchmark callbacks
RxInfer provides a built-in callback structure, RxInferBenchmarkCallbacks, for collecting timing information during inference. It aggregates timestamps across multiple inference runs, so you can compute performance statistics (minimum, maximum, mean, median, and standard deviation) for both model creation and the inference procedure. For general information about the callbacks system, see Callbacks.
Basic usage
```julia
using RxInfer

@model function iid_normal(y)
    μ ~ Normal(mean = 0.0, variance = 100.0)
    γ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = γ)
end

init = @initialization begin
    q(μ) = vague(NormalMeanVariance)
end

# Warm up once (without callbacks) so that compilation time is not included in the measurements
infer(
    model = iid_normal(),
    data = (y = randn(100),),
    constraints = MeanField(),
    iterations = 5,
    initialization = init,
)

# Create a benchmark callbacks instance to track performance
benchmark_callbacks = RxInferBenchmarkCallbacks()

# Run inference multiple times to gather statistics
for i in 1:5
    infer(
        model = iid_normal(),
        data = (y = randn(100),),
        constraints = MeanField(),
        iterations = 5,
        initialization = init,
        callbacks = benchmark_callbacks,
    )
end
```

The benchmark callbacks instance accumulates timestamps across multiple calls to infer, making it easy to collect performance statistics over many runs.
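Conceptually, each call to infer contributes one before/after timestamp pair per tracked event, and a duration is simply the difference of the two. The following is a minimal sketch of that pattern in plain Julia using Base's time_ns(); TimingLog, record!, and durations are hypothetical names for illustration only, not part of RxInfer's API or its actual implementation.

```julia
# Hypothetical illustration, not RxInfer's implementation: like
# RxInferBenchmarkCallbacks, it keeps one before/after timestamp pair per run.
struct TimingLog
    before::Vector{UInt64}
    after::Vector{UInt64}
end
TimingLog() = TimingLog(UInt64[], UInt64[])

# Record monotonic timestamps (in nanoseconds) around a call to `f`
function record!(tl::TimingLog, f)
    push!(tl.before, time_ns())
    result = f()
    push!(tl.after, time_ns())
    return result
end

# Duration of every recorded run, in nanoseconds
durations(tl::TimingLog) = tl.after .- tl.before

timings = TimingLog()
for _ in 1:5
    record!(timings, () -> sum(rand(10_000)))
end
println(length(durations(timings)), " runs recorded")  # 5 runs recorded
```

Because the log is never reset between calls, every additional run simply appends another pair, which is the same accumulate-then-summarise idea the benchmark callbacks rely on.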
Displaying results
Install PrettyTables.jl to display the collected statistics in a nicely formatted table:

```julia
using PrettyTables
PrettyTables.pretty_table(benchmark_callbacks)
```

```
RxInfer inference benchmark statistics: 5 evaluations
╭────────────────┬────────────┬────────────┬────────────┬────────────┬────────────╮
│ Operation │ Min │ Max │ Mean │ Median │ Std │
├────────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ Model creation │ 3.457 ms │ 3.726 ms │ 3.555 ms │ 3.507 ms │ 111.983 μs │
│ Inference │ 1.730 ms │ 1.840 ms │ 1.760 ms │ 1.745 ms │ 44.791 μs │
│ Iteration │ 327.613 μs │ 410.706 μs │ 343.324 μs │ 337.488 μs │ 17.549 μs │
╰────────────────┴────────────┴────────────┴────────────┴────────────┴────────────╯
```

Accessing from model metadata
After model creation, the benchmark callbacks instance is automatically saved into the model's metadata under the :benchmark key. This makes it accessible from the inference result without needing to hold onto the callbacks object separately:
```julia
result = infer(
    model = iid_normal(),
    data = (y = randn(100),),
    constraints = MeanField(),
    iterations = 5,
    initialization = init,
    callbacks = RxInferBenchmarkCallbacks(),
)

benchmark = result.model.metadata[:benchmark]
println(benchmark)
```

```
RxInferBenchmarkCallbacks (1evaluations, use `pretty_table` from `PrettyTables.jl` to display the statistics in a tabular format)
```

Tracked events
The RxInferBenchmarkCallbacks structure collects timestamps at the following stages:
| Event | Batch inference | Streamline inference |
|---|---|---|
| Model creation (before/after) | yes | yes |
| Inference (before/after) | yes | no |
| Each iteration (before/after) | yes | no |
| Autostart (before/after) | no | yes |
Buffer capacity
By default, the structure uses circular buffers with a capacity of RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY (1000) entries. Once a buffer is full, each new timestamp overwrites the oldest one, which bounds memory usage in long-running applications. You can change the capacity:
```julia
# Store up to 10000 benchmark entries
large_buffer_callbacks = RxInferBenchmarkCallbacks(capacity = 10_000)
```

Programmatic access to statistics
Use RxInfer.get_benchmark_stats to retrieve the raw statistics matrix:
```julia
# Use the previously populated benchmark_callbacks
stats = RxInfer.get_benchmark_stats(benchmark_callbacks)
for row in eachrow(stats)
    # Times are stored in nanoseconds; dividing by 1e6 gives milliseconds
    println(row[1], ": min=", round(row[2] / 1e6, digits=2), "ms, mean=", round(row[4] / 1e6, digits=2), "ms")
end
```

```
Model creation: min=3.46ms, mean=3.55ms
Inference: min=1.73ms, mean=1.76ms
Iteration: min=0.33ms, mean=0.34ms
```

The matrix contains the following columns:

- Operation name (String)
- Minimum time (Float64, nanoseconds)
- Maximum time (Float64, nanoseconds)
- Mean time (Float64, nanoseconds)
- Median time (Float64, nanoseconds)
- Standard deviation (Float64, nanoseconds)
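To make the column layout concrete, here is a sketch of how one such row could be derived from paired before/after timestamps with the Statistics standard library. The timestamps are made up, and using the sample standard deviation (std with its default corrected = true) is an assumption; RxInfer's exact computation may differ.

```julia
using Statistics

# Made-up before/after timestamps (nanoseconds) for four runs of one operation
before = UInt64[1_000, 5_000, 9_000, 13_000]
after  = UInt64[3_500, 7_400, 11_600, 15_500]

# Durations as Float64 nanoseconds: [2500.0, 2400.0, 2600.0, 2500.0]
d = Float64.(after .- before)

# One row in the documented column order:
# name, min, max, mean, median, std (all times in nanoseconds)
row = Any["Model creation", minimum(d), maximum(d), mean(d), median(d), std(d)]
stats = permutedims(row)  # 1×6 Matrix{Any}
```

Stacking one such row per operation (model creation, inference, iteration, autostart) gives a matrix with the same shape as the one described above.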
API Reference
RxInfer.RxInferBenchmarkCallbacks — Type
```julia
RxInferBenchmarkCallbacks(; capacity = RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)
```

A callback structure for collecting timing information during the inference procedure. This structure collects timestamps for various stages of the inference process and aggregates them across multiple runs, allowing you to track performance statistics (min/max/average/etc.) of your model's creation and inference procedure. The structure supports pretty printing by default, displaying timing statistics in a human-readable format.
The structure uses circular buffers with a default capacity of 1000 entries to store timestamps, which helps to limit memory usage in long-running applications. Use RxInferBenchmarkCallbacks(; capacity = N) to change the buffer capacity. See also RxInfer.get_benchmark_stats(callbacks).
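The capacity bound works because a circular buffer overwrites its oldest entry once it is full, so only the most recent N timestamps are retained. The sketch below illustrates that behaviour with a hypothetical Ring type written against Base only; it is not the CircularBuffer type RxInfer actually uses.

```julia
# Hypothetical minimal ring buffer, only to illustrate the overwrite behaviour
mutable struct Ring{T}
    data::Vector{T}
    capacity::Int
    first::Int   # index of the oldest element in `data`
end

function Ring{T}(capacity::Int) where {T}
    capacity > 0 || throw(ArgumentError("capacity must be positive"))
    return Ring{T}(T[], capacity, 1)
end

function Base.push!(r::Ring, x)
    if length(r.data) < r.capacity
        push!(r.data, x)                         # still room: append
    else
        r.data[r.first] = x                      # full: overwrite the oldest entry
        r.first = mod1(r.first + 1, r.capacity)  # next-oldest becomes the oldest
    end
    return r
end

# Elements ordered from oldest to newest
contents(r::Ring) = [r.data[mod1(r.first + i - 1, length(r.data))] for i in 1:length(r.data)]

r = Ring{Int}(3)
foreach(x -> push!(r, x), 1:5)
println(contents(r))  # [3, 4, 5]
```

Memory use stays at capacity entries no matter how many values are pushed, at the cost of discarding the oldest measurements.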
After model creation, the benchmark callbacks instance is automatically saved into the model's metadata under the :benchmark key (i.e., model.metadata[:benchmark]), making it accessible from the inference result via result.model.metadata[:benchmark].
Fields
- before_model_creation_ts: CircularBuffer of timestamps before model creation
- after_model_creation_ts: CircularBuffer of timestamps after model creation
- before_inference_ts: CircularBuffer of timestamps before inference starts
- after_inference_ts: CircularBuffer of timestamps after inference ends
- before_iteration_ts: CircularBuffer of vectors of timestamps before each iteration
- after_iteration_ts: CircularBuffer of vectors of timestamps after each iteration
- before_autostart_ts: CircularBuffer of timestamps before autostart
- after_autostart_ts: CircularBuffer of timestamps after autostart
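The iteration fields differ from the others: each entry is a vector holding one timestamp per iteration of a single run. The sketch below turns such nested data into per-iteration durations, using made-up timestamps in place of the real buffers; pooling all iterations of all runs into one sample is an assumption about how an aggregate such as the Iteration row might be formed.

```julia
using Statistics

# Made-up data shaped like before_iteration_ts / after_iteration_ts:
# one inner vector per inference run, one timestamp per iteration (nanoseconds)
before_iteration = [UInt64[0, 400, 800], UInt64[2_000, 2_350, 2_700]]
after_iteration  = [UInt64[300, 700, 1_150], UInt64[2_330, 2_680, 3_010]]

# Per-run vectors of iteration durations, as Float64 nanoseconds
per_run = [Float64.(a .- b) for (b, a) in zip(before_iteration, after_iteration)]

# Pool the iterations of all runs into one sample (an assumption, not necessarily RxInfer's method)
all_iterations = reduce(vcat, per_run)
println(length(all_iterations), " iterations, mean = ", mean(all_iterations), " ns")
# prints: 6 iterations, mean = 320.0 ns
```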
Example
```julia
# Create a callbacks instance to track performance
callbacks = RxInferBenchmarkCallbacks()

# Run inference multiple times to gather statistics
# (my_model and my_data are placeholders for your own model and data)
result = nothing
for _ in 1:10
    global result = infer(
        model = my_model(),
        data = my_data,
        callbacks = callbacks
    )
end

# Access the benchmark callbacks from the inference result
result.model.metadata[:benchmark] === callbacks # true

# Display the timing statistics (you need to install PrettyTables.jl to use the pretty_table function)
using PrettyTables
PrettyTables.pretty_table(callbacks)
```

RxInfer.get_benchmark_stats — Function
```julia
get_benchmark_stats(callbacks::RxInferBenchmarkCallbacks)
```

Returns a matrix containing benchmark statistics for different operations in the inference process. The matrix contains the following columns:
- Operation name (String)
- Minimum time (Float64)
- Maximum time (Float64)
- Mean time (Float64)
- Median time (Float64)
- Standard deviation (Float64)
Each row represents a different operation (model creation, inference, iteration, autostart). Times are in nanoseconds.
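Because all times are in nanoseconds, a small conversion step is often convenient. The helper below is a hypothetical sketch (stats_in_ms is not part of RxInfer) that assumes only the column order documented above; it is exercised on a made-up one-row matrix rather than real benchmark output.

```julia
# Hypothetical helper: convert a matrix with the documented column layout
# (name, min, max, mean, median, std; times in ns) into NamedTuples in milliseconds
function stats_in_ms(stats::AbstractMatrix)
    out = Dict{String, NamedTuple}()
    for row in eachrow(stats)
        ms = Float64.(row[2:6]) ./ 1e6   # nanoseconds -> milliseconds
        out[String(row[1])] = (min = ms[1], max = ms[2], mean = ms[3], median = ms[4], std = ms[5])
    end
    return out
end

# Illustrative input only (made-up numbers, not a real benchmark)
example = Any["Inference" 1.5e6 2.0e6 1.7e6 1.65e6 1.0e5]
s = stats_in_ms(example)
println(s["Inference"].mean)  # 1.7
```

With RxInfer loaded, the same helper could be applied to the matrix returned by RxInfer.get_benchmark_stats, giving lookups such as s["Model creation"].median.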
RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY — Constant
```julia
DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY
```

The default capacity (1000 entries) of the circular buffers used to store timestamps in the RxInferBenchmarkCallbacks structure.