Skip to content

Commit

Permalink
Merge pull request #616 from TuringLang/csp/poisson-fix
Browse files Browse the repository at this point in the history
Enable backprop through poislogpdf
  • Loading branch information
yebai authored Dec 7, 2018
2 parents df476ad + 5f81d3f commit 6b26d1d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,11 @@ Tracker.@grad function binomlogpdf(n::Int, p::Tracker.TrackedReal, x::Int)
return binomlogpdf(n, Tracker.data(p), x),
Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing)
end


import StatsFuns: poislogpdf
poislogpdf(v::Tracker.TrackedReal, x::Int) = Tracker.track(poislogpdf, v, x)
Tracker.@grad function poislogpdf(v::Tracker.TrackedReal, x::Int)
return poislogpdf(Tracker.data(v), x),
Δ->* (x/v - 1), nothing)
end
18 changes: 18 additions & 0 deletions test/ad.jl/AD_compatibility_with_distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,21 @@ let
atol=1e-8,
)
end

let
foo = p->poislogpdf(1, p)
@test isapprox(
Tracker.gradient(foo, 0.5)[1],
central_fdm(5, 1)(foo, 0.5);
rtol=1e-8,
atol=1e-8,
)

bar = p->logpdf(Poisson(1), 3)
@test isapprox(
Tracker.gradient(bar, 0.5)[1],
central_fdm(5, 1)(bar, 0.5);
rtol=1e-8,
atol=1e-8,
)
end

0 comments on commit 6b26d1d

Please sign in to comment.