Merge pull request #342 from blolt/node_contraction
Add custom node contraction
wouterwln authored Sep 30, 2024
2 parents 9facd12 + f5150d9 commit fc028f2
Showing 10 changed files with 356 additions and 16 deletions.
25 changes: 24 additions & 1 deletion Project.toml
@@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactiveMP = "a194aa59-28ba-4574-a09c-4a745416d6e3"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[weakdeps]
@@ -43,6 +44,7 @@ ProgressMeter = "1.0.0"
Random = "1.9"
ReactiveMP = "~4.4.1"
Reexport = "1.2.0"
Rocket = "1.8.0"
Static = "0.8.10"
TupleTools = "1.2.0"
julia = "1.10"
@@ -70,4 +72,25 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"

[targets]
test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "CpuId", "Dates", "Distributed", "Documenter", "ExponentialFamilyProjection", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers", "ReTestItems"]
test = [
"Test",
"Pkg",
"Logging",
"InteractiveUtils",
"TestSetExtensions",
"Coverage",
"CpuId",
"Dates",
"Distributed",
"Documenter",
"ExponentialFamilyProjection",
"Plots",
"BenchmarkCI",
"BenchmarkTools",
"PkgBenchmark",
"Aqua",
"StableRNGs",
"StatsFuns",
"Optimisers",
"ReTestItems",
]
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -1,5 +1,6 @@
[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
113 changes: 113 additions & 0 deletions docs/src/manuals/model-specification.md
@@ -257,6 +257,119 @@ model = RxInfer.create_model(conditioned)
GraphPlot.gplot(RxInfer.getmodel(model))
```

## Node Contraction

RxInfer's model specification extension for GraphPPL supports a feature called _node contraction_. This feature allows you to _contract_ (or _replace_) a submodel with a corresponding factor node. Node contraction can be useful in several scenarios:

- When running inference in a submodel is computationally expensive
- When a submodel contains many variables whose inference results are not of primary importance
- When specialized message passing update rules can be derived for variables in the Markov blanket of the submodel

Let's illustrate this concept with a simple example. We'll first create a basic submodel and then allow the inference backend to replace it with a corresponding node that has well-defined message update rules.

```@example node-contraction
using RxInfer, Plots
@model function ShiftedNormal(data, mean, precision, shift)
shifted_mean := mean + shift
data ~ Normal(mean = shifted_mean, precision = precision)
end
@model function Model(data, precision, shift)
mean ~ Normal(mean = 15.0, var = 1.0)
data ~ ShiftedNormal(mean = mean, precision = precision, shift = shift)
end
result = infer(
model = Model(precision = 1.0, shift = 1.0),
data = (data = 10.0, )
)
plot(title = "Inference results over `mean`")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")
```

As we can see, inference runs on this model as usual. We can also visualize the model's structure, as shown in the [Model structure visualisation](@ref user-guide-model-specification-visualization) section.

```@example node-contraction
using Cairo, GraphPlot
GraphPlot.gplot(getmodel(result.model))
```

Now, let's create an optimized version of the `ShiftedNormal` submodel as a standalone node with its own message passing update rules.

!!! note
Creating correct message passing update rules is beyond the scope of this section. For more information about custom message passing update rules, refer to the [Custom Node](@ref create-node) section.

```@example node-contraction
@node typeof(ShiftedNormal) Stochastic [ data, mean, precision, shift ]
@rule typeof(ShiftedNormal)(:mean, Marginalisation) (q_data::PointMass, q_precision::PointMass, q_shift::PointMass, ) = begin
return @call_rule NormalMeanPrecision(:μ, Marginalisation) (q_out = PointMass(mean(q_data) - mean(q_shift)), q_τ = q_precision)
end
result_with_contraction = infer(
model = Model(precision = 1.0, shift = 1.0),
data = (data = 10.0, ),
allow_node_contraction = true
)
using Test #hide
@test result.posteriors[:mean] ≈ result_with_contraction.posteriors[:mean] #hide
plot(title = "Inference results over `mean` with node contraction")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result_with_contraction.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")
```

As you can see, the inference result is identical to the previous case. However, the structure of the model is different:

```@example node-contraction
GraphPlot.gplot(getmodel(result_with_contraction.model))
```

With node contraction, we no longer have access to the variables defined inside the `ShiftedNormal` submodel, as it has been contracted to a single factor node. It's worth noting that this feature heavily relies on existing message passing update rules for the submodel. However, it can also be combined with another useful inference technique [where no explicit message passing update rules are required](@ref inference-undefinedrules).
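Under the hood, contraction hinges on the backend recognizing the submodel as a predefined node. As a sketch of how this could be checked interactively (this assumes `ReactiveMP.is_predefined_node`, which the backend in this commit calls internally, is also callable from user code):

```julia
using RxInfer
import ReactiveMP

# After the `@node typeof(ShiftedNormal)` definition above, the backend's
# predefined-node check should no longer report `UndefinedNodeFunctionalForm`,
# which is what makes the submodel eligible for contraction when
# `allow_node_contraction = true` is passed to `infer`
ReactiveMP.is_predefined_node(ShiftedNormal)
```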

We can also verify that node contraction indeed improves inference performance:

```@example node-contraction
using BenchmarkTools
benchmark_without_contraction = @benchmark infer(
model = Model(precision = 1.0, shift = 1.0),
data = (data = 10.0, )
)
benchmark_with_contraction = @benchmark infer(
model = Model(precision = 1.0, shift = 1.0),
data = (data = 10.0, ),
allow_node_contraction = true
)
using Test #hide
@test benchmark_with_contraction.allocs < benchmark_without_contraction.allocs #hide
@test mean(benchmark_with_contraction.times) < mean(benchmark_without_contraction.times) #hide
@test median(benchmark_with_contraction.times) < median(benchmark_without_contraction.times) #hide
@test minimum(benchmark_with_contraction.times) < minimum(benchmark_without_contraction.times) #hide
nothing #hide
```

Let's examine the benchmark results:

```@example node-contraction
benchmark_without_contraction
```

```@example node-contraction
benchmark_with_contraction
```

As we can see, the inference with node contraction runs faster due to the simplified model structure and optimized message update rules.
This performance improvement is reflected in reduced execution time and fewer memory allocations.
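To put a number on the improvement, `BenchmarkTools` can compare the two trials directly. A minimal sketch, assuming the `benchmark_with_contraction` and `benchmark_without_contraction` variables from the benchmark block above are still in scope:

```julia
using BenchmarkTools

# `judge` compares two trial estimates and classifies the difference
# as an improvement, a regression, or statistically invariant
comparison = judge(median(benchmark_with_contraction), median(benchmark_without_contraction))

# `ratio` gives the raw ratio of the median estimates;
# a time ratio below 1.0 means the contracted model is faster
speedup = ratio(median(benchmark_with_contraction), median(benchmark_without_contraction))
```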

### [Node creation options](@id user-guide-model-specification-node-creation-options)

`GraphPPL` allows passing optional arguments to the node creation constructor with the `where { options... }` options specification syntax.
8 changes: 6 additions & 2 deletions src/inference/batch.jl
@@ -1,3 +1,5 @@
import Static

"""
InferenceResult
@@ -115,6 +117,8 @@ function batch_inference(;
free_energy = false,
# Default BFE stream checks
free_energy_diagnostics = DefaultObjectiveDiagnosticChecks,
# Enables node contraction with additional implementation, optional, defaults to false.
allow_node_contraction = false,
# Show progress module, optional, defaults to false
showprogress = false,
# Inference cycle callbacks
@@ -156,11 +160,11 @@
end

# The `_model` here still must be a `ModelGenerator`
_model = GraphPPL.with_plugins(model, modelplugins)
_model = GraphPPL.with_backend(GraphPPL.with_plugins(model, modelplugins), ReactiveMPGraphPPLBackend(Static.static(allow_node_contraction)))

infer_check_dicttype(:data, data)

# If `predictvars` is specified implicitly as `KeepEach` or `KeepLast`, we replace it with the same value for each data variable
if (predictvars === KeepEach() || predictvars === KeepLast())
if !isnothing(data)
predictoption = predictvars
3 changes: 3 additions & 0 deletions src/inference/inference.jl
@@ -273,6 +273,7 @@ function infer(;
iterations = nothing,
free_energy = false,
free_energy_diagnostics = DefaultObjectiveDiagnosticChecks,
allow_node_contraction = false,
showprogress = false, # batch specific
catch_exception = false, # batch specific
callbacks = nothing,
@@ -315,6 +316,7 @@
iterations = iterations,
free_energy = free_energy,
free_energy_diagnostics = free_energy_diagnostics,
allow_node_contraction = allow_node_contraction,
showprogress = showprogress,
callbacks = callbacks,
addons = addons,
@@ -340,6 +342,7 @@
iterations = iterations,
free_energy = free_energy,
free_energy_diagnostics = free_energy_diagnostics,
allow_node_contraction = allow_node_contraction,
autostart = autostart,
callbacks = callbacks,
addons = addons,
5 changes: 4 additions & 1 deletion src/inference/streaming.jl
@@ -1,3 +1,5 @@
import Static

"""
RxInferenceEngine
@@ -455,6 +457,7 @@ function streaming_inference(;
iterations = nothing,
free_energy = false,
free_energy_diagnostics = DefaultObjectiveDiagnosticChecks,
allow_node_contraction = false,
autostart = true,
events = nothing,
addons = nothing,
Expand Down Expand Up @@ -510,7 +513,7 @@ function streaming_inference(;
end

# The `_model` here still must be a `ModelGenerator`
_model = GraphPPL.with_plugins(model, modelplugins)
_model = GraphPPL.with_backend(GraphPPL.with_plugins(model, modelplugins), ReactiveMPGraphPPLBackend(Static.static(allow_node_contraction)))
_autoupdates = something(autoupdates, EmptyAutoUpdateSpecification)

check_model_generator_compatibility(_autoupdates, _model)
36 changes: 30 additions & 6 deletions src/model/graphppl.jl
@@ -1,13 +1,16 @@
import GraphPPL
import MacroTools
import ExponentialFamily
import Static

import MacroTools: @capture

"""
A backend for GraphPPL that uses ReactiveMP for inference.
"""
struct ReactiveMPGraphPPLBackend end
struct ReactiveMPGraphPPLBackend{T}
should_contract_node::T
end

# Model specification with `@model` macro

@@ -143,7 +146,7 @@ See the documentation to [`GraphPPL.@model`](https:/ReactiveBayes/Gr
$(begin io = IOBuffer(); RxInfer.show_tilderhs_alias(io); String(take!(io)) end)
"""
macro model(model_specification)
return esc(GraphPPL.model_macro_interior(ReactiveMPGraphPPLBackend, model_specification))
return esc(GraphPPL.model_macro_interior(ReactiveMPGraphPPLBackend{Static.False}, model_specification))
end

# Backend specific methods
@@ -159,11 +162,25 @@ function GraphPPL.NodeBehaviour(backend::ReactiveMPGraphPPLBackend, ::ReactiveMP
return GraphPPL.Stochastic()
end

function GraphPPL.NodeType(::ReactiveMPGraphPPLBackend, something::F) where {F}
# Fallback to the default behaviour
# If node contraction is enabled, we need to check if the node is predefined in `ReactiveMP`
# if this is the case, we use the `Atomic` node type, otherwise we fallback to the `DefaultBackend`
function GraphPPL.NodeType(backend::ReactiveMPGraphPPLBackend{Static.True}, something::F) where {F}
return GraphPPL.NodeType(backend, ReactiveMP.is_predefined_node(something), something)
end
function GraphPPL.NodeType(backend::ReactiveMPGraphPPLBackend{Static.True}, ::ReactiveMP.UndefinedNodeFunctionalForm, something::F) where {F}
# Fallback to the default behaviour if the node is not predefined
return GraphPPL.NodeType(ReactiveMPGraphPPLBackend(Static.False()), something)
end
function GraphPPL.NodeType(backend::ReactiveMPGraphPPLBackend{Static.True}, ::ReactiveMP.PredefinedNodeFunctionalForm, something::F) where {F}
# The node is predefined in `ReactiveMP`, so treat it as an atomic factor node
return GraphPPL.Atomic()
end

# Fallback to the default behaviour
function GraphPPL.NodeType(::ReactiveMPGraphPPLBackend{Static.False}, something::F) where {F}
return GraphPPL.NodeType(GraphPPL.DefaultBackend(), something)
end
function GraphPPL.aliases(::ReactiveMPGraphPPLBackend, something::F) where {F}
function GraphPPL.aliases(::ReactiveMPGraphPPLBackend{Static.False}, something::F) where {F}
# Fallback to the default behaviour
return GraphPPL.aliases(GraphPPL.DefaultBackend(), something)
end
@@ -208,7 +225,14 @@ function GraphPPL.default_parametrization(backend::ReactiveMPGraphPPLBackend, no
end

function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend})
return ReactiveMPGraphPPLBackend()
return ReactiveMPGraphPPLBackend(Static.False())
end

function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend{Static.True}})
return ReactiveMPGraphPPLBackend(Static.True())
end
function GraphPPL.instantiate(::Type{ReactiveMPGraphPPLBackend{Static.False}})
return ReactiveMPGraphPPLBackend(Static.False())
end

# Node specific aliases