Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InplaceLogpdf wrapper type #12

Merged
merged 6 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,5 @@ BayesBase.distribution_typewrapper
BayesBase.CountingReal
BayesBase.Infinity
BayesBase.MinusInfinity
BayesBase.InplaceLogpdf
```
69 changes: 69 additions & 0 deletions src/statsfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,72 @@
function Base.:(==)(left::CountingReal{T}, right::CountingReal{T}) where {T}
return (value(left) == value(right)) && (infinities(left) == infinities(right))
end

"""
InplaceLogpdf(logpdf!)

Wraps a `logpdf!` function in a type that can later on be used for dispatch.
The sole purpose of this wrapper type is to allow for in-place logpdf operation on a batch of samples.
Accepts a function `logpdf!` that takes two arguments: `out` and `sample` and writes the logpdf of the sample to the `out` array.
A regular `logpdf` function can be converted to `logpdf!` by using `convert(InplaceLogpdf, logpdf)`.

```jldoctest

Check failure on line 338 in src/statsfuns.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:338-358 ```jldoctest julia> using Distributions, BayesBase julia> d = Beta(2, 3); julia> inplace = convert(BayesBase.InplaceLogpdf, (sample) -> logpdf(d, sample)); julia> out = zeros(9); julia> inplace(out, 0.1:0.1:0.9) 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173 ``` Subexpression: using Distributions, BayesBase Evaluated output: ERROR: ArgumentError: Package Distributions not found in current path. - Run `import Pkg; Pkg.add("Distributions")` to install the Distributions package. Stacktrace: [1] macro expansion @ ./loading.jl:1772 [inlined] [2] macro expansion @ ./lock.jl:267 [inlined] [3] __require(into::Module, mod::Symbol) @ Base ./loading.jl:1753 [4] #invoke_in_world#3 @ ./essentials.jl:926 [inlined] [5] invoke_in_world @ ./essentials.jl:923 [inlined] [6] require(into::Module, mod::Symbol) @ Base ./loading.jl:1746 Expected output: diff = Warning: Diff output requires color. ERROR: ArgumentError: Package Distributions not found in current path. - Run `import Pkg; Pkg.add("Distributions")` to install the Distributions package. Stacktrace: [1] macro expansion @ ./loading.jl:1772 [inlined] [2] macro expansion @ ./lock.jl:267 [inlined] [3] __require(into::Module, mod::Symbol) @ Base ./loading.jl:1753 [4] #invoke_in_world#3 @ ./essentials.jl:926 [inlined] [5] invoke_in_world @ ./essentials.jl:923 [inlined] [6] require(into::Module, mod::Symbol) @ Base ./loading.jl:1746

Check failure on line 338 in src/statsfuns.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:338-358 ```jldoctest julia> using Distributions, BayesBase julia> d = Beta(2, 3); julia> inplace = convert(BayesBase.InplaceLogpdf, (sample) -> logpdf(d, sample)); julia> out = zeros(9); julia> inplace(out, 0.1:0.1:0.9) 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173 ``` Subexpression: d = Beta(2, 3); Evaluated output: ERROR: UndefVarError: `Beta` not defined Stacktrace: [1] top-level scope @ none:1 Expected output: diff = Warning: Diff output requires color. ERROR: UndefVarError: `Beta` not defined Stacktrace: [1] top-level scope @ none:1

Check failure on line 338 in src/statsfuns.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:338-358 ```jldoctest julia> using Distributions, BayesBase julia> d = Beta(2, 3); julia> inplace = convert(BayesBase.InplaceLogpdf, (sample) -> logpdf(d, sample)); julia> out = zeros(9); julia> inplace(out, 0.1:0.1:0.9) 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173 ``` Subexpression: inplace(out, 0.1:0.1:0.9) Evaluated output: ERROR: UndefVarError: `d` not defined Stacktrace: [1] (::var"#1#2")(sample::Float64) @ Main ./none:1 [2] map!(f::var"#1#2", dest::Vector{Float64}, A::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) @ Base ./abstractarray.jl:3278 [3] #4 @ ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:392 [inlined] [4] (::BayesBase.InplaceLogpdf{BayesBase.var"#4#5"{var"#1#2"}})(out::Vector{Float64}, x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) @ BayesBase ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:387 [5] top-level scope @ none:1 Expected output: 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173 diff = Warning: Diff output requires color. 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173ERROR: UndefVarError: `d` not defined Stacktrace: [1] (::var"#1#2")(sample::Float64) @ Main ./none:1 [2] map!(f::var"#1#2", dest::Vector{Float64}, A::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) @ Base ./abstractarray.jl:3278 [3] #4 @ ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:392 [inlined] [4] (::BayesBase.InplaceLogpdf{BayesBase.var"#4#5"{var"#1#2"}})(out::Vector{Float64}, x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) @ BayesBase ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:387 [5] top-level scope @ none:1
julia> using Distributions, BayesBase

julia> d = Beta(2, 3);

julia> inplace = convert(BayesBase.InplaceLogpdf, (sample) -> logpdf(d, sample));

julia> out = zeros(9);

julia> inplace(out, 0.1:0.1:0.9)
9-element Vector{Float64}:
-0.028399474521697776
0.42918163472548043
0.5675839575845996
0.5469646703818638
0.4054651081081646
0.14149956227369964
-0.2797139028026039
-0.9571127263944104
-2.2256240518579173
```

```jldoctest

Check failure on line 360 in src/statsfuns.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:360-380 ```jldoctest julia> using Distributions, BayesBase julia> d = Beta(2, 3); julia> inplace = BayesBase.InplaceLogpdf((out, sample) -> logpdf!(out, d, sample)); julia> out = zeros(9); julia> inplace(out, 0.1:0.1:0.9) 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173 ``` Subexpression: using Distributions, BayesBase Evaluated output: ERROR: ArgumentError: Package Distributions not found in current path. - Run `import Pkg; Pkg.add("Distributions")` to install the Distributions package. Stacktrace: [1] macro expansion @ ./loading.jl:1772 [inlined] [2] macro expansion @ ./lock.jl:267 [inlined] [3] __require(into::Module, mod::Symbol) @ Base ./loading.jl:1753 [4] #invoke_in_world#3 @ ./essentials.jl:926 [inlined] [5] invoke_in_world @ ./essentials.jl:923 [inlined] [6] require(into::Module, mod::Symbol) @ Base ./loading.jl:1746 Expected output: diff = Warning: Diff output requires color. ERROR: ArgumentError: Package Distributions not found in current path. - Run `import Pkg; Pkg.add("Distributions")` to install the Distributions package. Stacktrace: [1] macro expansion @ ./loading.jl:1772 [inlined] [2] macro expansion @ ./lock.jl:267 [inlined] [3] __require(into::Module, mod::Symbol) @ Base ./loading.jl:1753 [4] #invoke_in_world#3 @ ./essentials.jl:926 [inlined] [5] invoke_in_world @ ./essentials.jl:923 [inlined] [6] require(into::Module, mod::Symbol) @ Base ./loading.jl:1746

Check failure on line 360 in src/statsfuns.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:360-380 ```jldoctest julia> using Distributions, BayesBase julia> d = Beta(2, 3); julia> inplace = BayesBase.InplaceLogpdf((out, sample) -> logpdf!(out, d, sample)); julia> out = zeros(9); julia> inplace(out, 0.1:0.1:0.9) 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173 ``` Subexpression: d = Beta(2, 3); Evaluated output: ERROR: UndefVarError: `Beta` not defined Stacktrace: [1] top-level scope @ none:1 Expected output: diff = Warning: Diff output requires color. ERROR: UndefVarError: `Beta` not defined Stacktrace: [1] top-level scope @ none:1

Check failure on line 360 in src/statsfuns.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:360-380 ```jldoctest julia> using Distributions, BayesBase julia> d = Beta(2, 3); julia> inplace = BayesBase.InplaceLogpdf((out, sample) -> logpdf!(out, d, sample)); julia> out = zeros(9); julia> inplace(out, 0.1:0.1:0.9) 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173 ``` Subexpression: inplace(out, 0.1:0.1:0.9) Evaluated output: ERROR: UndefVarError: `logpdf!` not defined Stacktrace: [1] (::var"#1#2")(out::Vector{Float64}, sample::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) @ Main ./none:1 [2] (::BayesBase.InplaceLogpdf{var"#1#2"})(out::Vector{Float64}, x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) @ BayesBase ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:387 [3] top-level scope @ none:1 Expected output: 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173 diff = Warning: Diff output requires color. 9-element Vector{Float64}: -0.028399474521697776 0.42918163472548043 0.5675839575845996 0.5469646703818638 0.4054651081081646 0.14149956227369964 -0.2797139028026039 -0.9571127263944104 -2.2256240518579173ERROR: UndefVarError: `logpdf!` not defined Stacktrace: [1] (::var"#1#2")(out::Vector{Float64}, sample::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) @ Main ./none:1 [2] (::BayesBase.InplaceLogpdf{var"#1#2"})(out::Vector{Float64}, x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) @ BayesBase ~/work/BayesBase.jl/BayesBase.jl/src/statsfuns.jl:387 [3] top-level scope @ none:1
julia> using Distributions, BayesBase

julia> d = Beta(2, 3);

julia> inplace = BayesBase.InplaceLogpdf((out, sample) -> logpdf!(out, d, sample));

julia> out = zeros(9);

julia> inplace(out, 0.1:0.1:0.9)
9-element Vector{Float64}:
-0.028399474521697776
0.42918163472548043
0.5675839575845996
0.5469646703818638
0.4054651081081646
0.14149956227369964
-0.2797139028026039
-0.9571127263944104
-2.2256240518579173
```
"""
struct InplaceLogpdf{F}
logpdf!::F
end

function (inplace::InplaceLogpdf)(out, x)
inplace.logpdf!(out, x)
return out
end

function Base.convert(::Type{InplaceLogpdf}, something)
return InplaceLogpdf((out, x) -> map!(something, out, x))
end

function Base.convert(::Type{InplaceLogpdf}, inplace::InplaceLogpdf)
return inplace
end
60 changes: 58 additions & 2 deletions test/statsfuns_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ end
end
end

@testitem "dtanh" begin
@testitem "dtanh" begin
for T in (Float32, Float64, BigFloat)
foreach(rand(T, 10)) do number
@test dtanh(number) ≈ 1 - tanh(number) ^ 2
@test dtanh(number) ≈ 1 - tanh(number)^2
end
end
end
Expand Down Expand Up @@ -87,6 +87,62 @@ end

@test float(convert(CountingReal, r)) ≈ zero(T)
@test float(convert(CountingReal{Float64}, r)) ≈ zero(Float64)
end
end

@testitem "InplaceLogpdf" begin
import BayesBase: InplaceLogpdf
using Distributions, LinearAlgebra, StableRNGs

@testset "Vector based samples" begin
distribution = Beta(10, 10)
fn = (x) -> logpdf(distribution, x)
inplacefn = convert(InplaceLogpdf, fn)

@test fn !== inplacefn

rng = StableRNG(42)
samples = rand(rng, distribution, 100)
evaluated = map(fn, samples)

container = similar(evaluated)
inplacefn(container, samples)

@test evaluated == container
end

@testset "Matrix based samples" begin
distribution = MvNormal(ones(2), ones(2))
fn = (x) -> logpdf(distribution, x)
inplacefn = convert(InplaceLogpdf, fn)

@test inplacefn !== fn

rng = StableRNG(42)
samples = rand(rng, distribution, 100)
evaluated = map(fn, eachcol(samples))

container = similar(evaluated)
inplacefn(container, eachcol(samples))

@test evaluated == container
end

@testset "Do not convert already inplace version" begin
distribution = MvNormal(ones(2), ones(2))
fn = InplaceLogpdf((out, x) -> logpdf!(out, distribution, x))
inplacefn = convert(InplaceLogpdf, fn)

@test inplacefn === fn

rng = StableRNG(42)
samples = rand(rng, distribution, 100)
evaluated = zeros(100)
fn(evaluated, eachcol(samples))

container = similar(evaluated)
inplacefn(container, eachcol(samples))

@test evaluated == container
end
end
Loading