diff --git a/docs/Project.toml b/docs/Project.toml index 2c96d67..7300733 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,4 @@ [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/src/index.md b/docs/src/index.md index d006112..3fc733e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -114,4 +114,5 @@ BayesBase.distribution_typewrapper BayesBase.CountingReal BayesBase.Infinity BayesBase.MinusInfinity +BayesBase.InplaceLogpdf ``` diff --git a/src/BayesBase.jl b/src/BayesBase.jl index c4a3489..f805ea5 100644 --- a/src/BayesBase.jl +++ b/src/BayesBase.jl @@ -30,6 +30,7 @@ import Distributions: pdf!, cdf, logpdf, + logpdf!, logdetcov, VariateForm, ValueSupport, @@ -58,6 +59,7 @@ export failprob, pdf!, cdf, logpdf, + logpdf!, logdetcov, VariateForm, ValueSupport, diff --git a/src/statsfuns.jl b/src/statsfuns.jl index 2e6a363..220e1a5 100644 --- a/src/statsfuns.jl +++ b/src/statsfuns.jl @@ -371,4 +371,73 @@ function mcov!( end return Z +end + +""" +InplaceLogpdf(logpdf!) + +Wraps a `logpdf!` function in a type that can later on be used for dispatch. +The sole purpose of this wrapper type is to allow for in-place logpdf operation on a batch of samples. +Accepts a function `logpdf!` that takes two arguments: `out` and `sample` and writes the logpdf of the sample to the `out` array. +A regular `logpdf` function can be converted to `logpdf!` by using `convert(InplaceLogpdf, logpdf)`. + +```jldoctest +julia> using Distributions, BayesBase + +julia> d = Beta(2, 3); + +julia> inplace = convert(BayesBase.InplaceLogpdf, (sample) -> logpdf(d, sample)); + +julia> out = zeros(9); + +julia> inplace(out, 0.1:0.1:0.9) +9-element Vector{Float64}: + -0.028399474521697776 + 0.42918163472548043 + 0.5675839575845996 + 0.5469646703818638 + 0.4054651081081646 + 0.14149956227369964 + -0.2797139028026039 + -0.9571127263944104 + -2.2256240518579173 +``` + +```jldoctest +julia> using Distributions, BayesBase + +julia> d = Beta(2, 3); + +julia> inplace = BayesBase.InplaceLogpdf((out, sample) -> logpdf!(out, d, sample)); + +julia> out = zeros(9); + +julia> inplace(out, 0.1:0.1:0.9) +9-element Vector{Float64}: + -0.028399474521697776 + 0.42918163472548043 + 0.5675839575845996 + 0.5469646703818638 + 0.4054651081081646 + 0.14149956227369964 + -0.2797139028026039 + -0.9571127263944104 + -2.2256240518579173 +``` +""" +struct InplaceLogpdf{F} + logpdf!::F +end + +function (inplace::InplaceLogpdf)(out, x) + inplace.logpdf!(out, x) + return out +end + +function Base.convert(::Type{InplaceLogpdf}, something) + return InplaceLogpdf((out, x) -> map!(something, out, x)) +end + +function Base.convert(::Type{InplaceLogpdf}, inplace::InplaceLogpdf) + return inplace end \ No newline at end of file diff --git a/test/statsfuns_tests.jl b/test/statsfuns_tests.jl index 7080aca..1e52e4a 100644 --- a/test/statsfuns_tests.jl +++ b/test/statsfuns_tests.jl @@ -115,4 +115,70 @@ end @report_opt mcov!(Z, X, Y; tmp1=tmp1, tmp2=tmp2, tmp3=tmp3, tmp4=tmp4) @test @allocated(mcov!(Z, X, Y; tmp1=tmp1, tmp2=tmp2, tmp3=tmp3, tmp4=tmp4)) === 0 end +end + +@testitem "InplaceLogpdf" begin + import BayesBase: InplaceLogpdf + using Distributions, LinearAlgebra, StableRNGs + + @testset "Vector based samples" begin + distribution = Beta(10, 10) + fn = (x) -> logpdf(distribution, x) + inplacefn = convert(InplaceLogpdf, fn) + + @test fn !== inplacefn + + rng = StableRNG(42) + samples = rand(rng, distribution, 100) + evaluated = map(fn, samples) + + container = similar(evaluated) + inplacefn(container, samples) + + @test evaluated == container + end + + @testset "Matrix based samples" begin + distribution = MvNormal(ones(2), ones(2)) + fn = (x) -> logpdf(distribution, x) + inplacefn = convert(InplaceLogpdf, fn) + + @test inplacefn !== fn + + rng = StableRNG(42) + samples = rand(rng, distribution, 100) + evaluated = map(fn, eachcol(samples)) + + container = similar(evaluated) + inplacefn(container, eachcol(samples)) + + @test evaluated == container + end + + @testset "Do not convert already inplace version" begin + distribution = MvNormal(ones(2), ones(2)) + fn = InplaceLogpdf((out, x) -> logpdf!(out, distribution, x)) + inplacefn = convert(InplaceLogpdf, fn) + + @test inplacefn === fn + + rng = StableRNG(42) + samples = rand(rng, distribution, 100) + evaluated = zeros(100) + fn(evaluated, eachcol(samples)) + + container = similar(evaluated) + inplacefn(container, eachcol(samples)) + + @test evaluated == container + end + + @testset "Shouldn't allocate anything for simple `logpdf!`" begin + fn = InplaceLogpdf((out, x) -> out .= log.(x)) + samples = 1:10 + out = zeros(10) + fn(out, samples) + @test out == log.(samples) + @test @allocated(fn(out, samples)) === 0 + end end \ No newline at end of file