diff --git a/src/core/ad.jl b/src/core/ad.jl index 1d466f93b..5fa0dec1c 100644 --- a/src/core/ad.jl +++ b/src/core/ad.jl @@ -125,3 +125,11 @@ Tracker.@grad function binomlogpdf(n::Int, p::Tracker.TrackedReal, x::Int) return binomlogpdf(n, Tracker.data(p), x), Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing) end + + +import StatsFuns: poislogpdf +poislogpdf(v::Tracker.TrackedReal, x::Int) = Tracker.track(poislogpdf, v, x) +Tracker.@grad function poislogpdf(v::Tracker.TrackedReal, x::Int) + return poislogpdf(Tracker.data(v), x), + Δ->(Δ * (x/v - 1), nothing) +end diff --git a/test/ad.jl/AD_compatibility_with_distributions.jl b/test/ad.jl/AD_compatibility_with_distributions.jl index 7ef15d54e..d65a5b985 100644 --- a/test/ad.jl/AD_compatibility_with_distributions.jl +++ b/test/ad.jl/AD_compatibility_with_distributions.jl @@ -46,3 +46,21 @@ let atol=1e-8, ) end + +let + foo = p->poislogpdf(1, p) + @test isapprox( + Tracker.gradient(foo, 0.5)[1], + central_fdm(5, 1)(foo, 0.5); + rtol=1e-8, + atol=1e-8, + ) + + bar = p->logpdf(Poisson(1), 3) + @test isapprox( + Tracker.gradient(bar, 0.5)[1], + central_fdm(5, 1)(bar, 0.5); + rtol=1e-8, + atol=1e-8, + ) +end