diff --git a/Project.toml b/Project.toml index 3567d7606..30d1bf1e8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.18" +version = "0.4.19" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index a79fed05b..ebc7592a3 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -387,23 +387,22 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) # unit test for this, but integration testing seems to catch it in multiple places. stmt == PiNode(nothing, Union{}) && return ad_stmt_info(line, nothing, stmt, nothing) - # Assume that the PiNode contains active data -- it's hard to see why a PiNode would be - # created for e.g. a constant. Error if code is encountered where this doesn't hold. - is_active(stmt.val) || unhandled_feature("PiNode: $stmt") - - # Get the primal type of this line, and the rdata refs for the `val` of this `PiNode` - # and this line itself. - P = get_primal_type(info, line) - val_rdata_ref_id = get_rev_data_id(info, stmt.val) - output_rdata_ref_id = get_rev_data_id(info, line) - - # Assemble the above lines and construct reverse-pass. - return ad_stmt_info( - line, - nothing, - PiNode(__inc(stmt.val), fcodual_type(_type(stmt.typ))), - Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id), - ) + if is_active(stmt.val) + # Get the primal type of this line, and the rdata refs for the `val` of this + # `PiNode` and this line itself. + P = get_primal_type(info, line) + val_rdata_ref_id = get_rev_data_id(info, stmt.val) + output_rdata_ref_id = get_rev_data_id(info, line) + fwds = PiNode(__inc(stmt.val), fcodual_type(_type(stmt.typ))) + rvs = Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id) + else + # If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to + # do on the reverse-pass. + fwds = PiNode(const_codual(stmt.val, info), fcodual_type(_type(stmt.typ))) + rvs = nothing + end + + return ad_stmt_info(line, nothing, fwds, rvs) end @inline function __pi_rvs!(::Type{P}, val_rdata_ref::Ref, output_rdata_ref::Ref) where {P} diff --git a/test/integration_testing/distributions.jl b/test/integration_testing/distributions.jl index 72ba86908..79bbbf288 100644 --- a/test/integration_testing/distributions.jl +++ b/test/integration_testing/distributions.jl @@ -227,16 +227,22 @@ _pdmat(A) = PDMat(_sym(A) + 5I) ), ( :none, - "allocs Normal", + "truncated Normal", (a, b, x) -> logpdf(truncated(Normal(), a, b), x), (-0.3, 0.3, 0.1), ), ( :none, - "allocs Uniform", + "truncated Uniform", (a, b, α, β, x) -> logpdf(truncated(Uniform(α, β), a, b), x), (0.1, 0.9, -0.1, 1.1, 0.4), ), + ( + :none, + "left-truncated Beta", + (a, α, β, x) -> logpdf(truncated(Beta(α, β), lower=a), x), + (0.1, 1.1, 1.3, 0.4), + ), (:none, "Dirichlet", (a, x) -> logpdf(Dirichlet(a), [x, 1-x]), ([1.5, 1.1], 0.6)), ( :none, diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 4e79b3ec6..263b1b040 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -137,11 +137,34 @@ end stmt_info = make_ad_stmts!(PiNode(nothing, Union{}), line, info) @test stmt_info isa ADStmtInfo end - @testset "unhandled case" begin - @test_throws( - Mooncake.UnhandledLanguageFeatureException, - make_ad_stmts!(PiNode(5.0, Float64), ID(), info), - ) + @testset "π (nothing, Nothing)" begin + stmt_info = make_ad_stmts!(PiNode(nothing, Nothing), id_line_1, info) + @test stmt_info isa ADStmtInfo + fwds_stmt = only(stmt_info.fwds)[2].stmt + @test fwds_stmt isa PiNode + @test fwds_stmt.val == CoDual(nothing, NoFData()) + @test fwds_stmt.typ == CoDual{Nothing, NoFData} + @test only(stmt_info.rvs)[2].stmt === nothing + end + @testset "π (nothing, CC.Const(nothing))" begin + node = PiNode(nothing, CC.Const(nothing)) + stmt_info = make_ad_stmts!(node, id_line_1, info) + @test stmt_info isa ADStmtInfo + fwds_stmt = only(stmt_info.fwds)[2].stmt + @test fwds_stmt isa PiNode + @test fwds_stmt.val == CoDual(nothing, NoFData()) + @test fwds_stmt.typ == CoDual{Nothing, NoFData} + @test only(stmt_info.rvs)[2].stmt === nothing + end + @testset "π (GlobalRef, Type)" begin + node = PiNode(GlobalRef(S2SGlobals, :const_float), Any) + stmt_info = make_ad_stmts!(node, id_line_1, info) + @test stmt_info isa ADStmtInfo + fwds_stmt = only(stmt_info.fwds)[2].stmt + @test fwds_stmt isa PiNode + @test fwds_stmt.val == CoDual(5.0, NoFData()) + @test fwds_stmt.typ == CoDual + @test only(stmt_info.rvs)[2].stmt === nothing end @testset "sharpen type of ID" begin line = id_line_1