Skip to content

Commit

Permalink
Replace instances of @submodel with to_submodel() (#751)
Browse files Browse the repository at this point in the history
* Replace remaining instances of @SubModel

* Implement tilde_assume!! for Sampleable / pointwise contexts

* Fix typos in test

* Bump patch version

* Re-add some minimal tests for deprecated @SubModel

* Format

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
penelopeysm and github-actions[bot] authored Dec 17, 2024
1 parent 8972b98 commit 681e472
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.32.0"
version = "0.32.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ adds the `Prefix` to all parameters.
This context is useful in nested models to ensure that the names of the parameters are
unique.
See also: [`@submodel`](@ref)
See also: [`to_submodel`](@ref)
"""
struct PrefixContext{Prefix,C} <: AbstractContext
context::C
Expand Down
20 changes: 20 additions & 0 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,26 @@ function _pointwise_tilde_observe(
end
end

# Note on submodels (penelopeysm)
#
# We don't need to overload tilde_observe!! for Sampleables (yet), because it
# is currently not possible to evaluate a model with a Sampleable on the RHS
# of an observe statement.
#
# Note that calling tilde_assume!! on a Sampleable does not necessarily imply
# that there are no observe statements inside the Sampleable. There could well
# be likelihood terms in there, which must be included in the returned logp.
# See e.g. the `demo_dot_assume_observe_submodel` demo model.
#
# This is handled by passing the same context to rand_like!!, which figures out
# which terms to include using the context, and also mutates the context and vi
# appropriately. Thus, we don't need to check against _include_prior(context)
# here.
function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi)
value, vi = DynamicPPL.rand_like!!(right, context, vi)
return value, vi
end

function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi)
!_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi))
value, logp, vi = tilde_assume(context.context, right, vn, vi)
Expand Down
9 changes: 6 additions & 3 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ end

@model function demo_assume_submodel_observe_index_literal()
# Submodel prior
@submodel s, m = _prior_dot_assume()
priors ~ to_submodel(_prior_dot_assume(), false)
s, m = priors
1.5 ~ Normal(m[1], sqrt(s[1]))
2.0 ~ Normal(m[2], sqrt(s[2]))

Expand All @@ -462,7 +463,7 @@ function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
end

@model function _likelihood_mltivariate_observe(s, m, x)
@model function _likelihood_multivariate_observe(s, m, x)
return x ~ MvNormal(m, Diagonal(s))
end

Expand All @@ -475,7 +476,9 @@ end
m .~ Normal.(0, sqrt.(s))

# Submodel likelihood
@submodel _likelihood_mltivariate_observe(s, m, x)
# With to_submodel, we have to have a left-hand side variable to
# capture the result, so we just use a dummy variable
_ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x))

return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
end
Expand Down
52 changes: 11 additions & 41 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,34 +382,13 @@ module Issue537 end
@test demo2()() == 42
end

@testset "@submodel is deprecated" begin
@model inner() = x ~ Normal()
@model outer() = @submodel x = inner()
@test_logs(
(
:warn,
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
),
outer()()
)

@model outer_with_prefix() = @submodel prefix = "sub" x = inner()
@test_logs(
(
:warn,
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
),
outer_with_prefix()()
)
end

@testset "submodel" begin
@testset "to_submodel" begin
# No prefix, 1 level.
@model function demo1(x)
return x ~ Normal()
end
@model function demo2(x, y)
@submodel demo1(x)
_ignore ~ to_submodel(demo1(x), false)
return y ~ Uniform()
end
# No observation.
Expand Down Expand Up @@ -441,7 +420,7 @@ module Issue537 end

# Check values makes sense.
@model function demo3(x, y)
@submodel demo1(x)
_ignore ~ to_submodel(demo1(x), false)
return y ~ Normal(x)
end
m = demo3(1000.0, missing)
Expand All @@ -453,12 +432,10 @@ module Issue537 end
x ~ Normal()
return x
end

@model function demo_useval(x, y)
@submodel prefix = "sub1" x1 = demo_return(x)
@submodel prefix = "sub2" x2 = demo_return(y)

return z ~ Normal(x1 + x2 + 100, 1.0)
sub1 ~ to_submodel(demo_return(x))
sub2 ~ to_submodel(demo_return(y))
return z ~ Normal(sub1 + sub2 + 100, 1.0)
end
m = demo_useval(missing, missing)
vi = VarInfo(m)
Expand All @@ -472,21 +449,18 @@ module Issue537 end
@model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV}
η ~ MvNormal(zeros(num_steps), I)
δ = sqrt(1 - α^2)

x = TV(undef, num_steps)
x[1] = η[1]
@inbounds for t in 2:num_steps
x[t] = @. α * x[t - 1] + δ * η[t]
end

return @. μ + σ * x
end

@model function demo(y)
α ~ Uniform()
μ ~ Normal()
σ ~ truncated(Normal(), 0, Inf)

num_steps = length(y[1])
num_obs = length(y)
@inbounds for i in 1:num_obs
Expand Down Expand Up @@ -613,14 +587,11 @@ module Issue537 end
@model demo() = x ~ Normal()
retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext())

# Return-value when using `@submodel`
# Return-value when using `to_submodel`
@model inner() = x ~ Normal()
# Without assignment.
@model outer() = @submodel inner()
@test outer()() isa Real

# With assignment.
@model outer() = @submodel x = inner()
@model function outer()
return _ignore ~ to_submodel(inner())
end
@test outer()() isa Real

# Edge-cases.
Expand Down Expand Up @@ -720,8 +691,7 @@ module Issue537 end
return (; x, y)
end
@model function demo_tracked_submodel()
@submodel (x, y) = demo_tracked()
return (; x, y)
return vals ~ to_submodel(demo_tracked(), false)
end
for model in [demo_tracked(), demo_tracked_submodel()]
# Make sure it's runnable and `y` is present in the return-value.
Expand Down
57 changes: 57 additions & 0 deletions test/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
@testset "deprecated" begin
@testset "@submodel" begin
@testset "is deprecated" begin
@model inner() = x ~ Normal()
@model outer() = @submodel x = inner()
@test_logs(
(
:warn,
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
),
outer()()
)

@model outer_with_prefix() = @submodel prefix = "sub" x = inner()
@test_logs(
(
:warn,
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
),
outer_with_prefix()()
)
end

@testset "prefixing still works correctly" begin
@model inner() = x ~ Normal()
@model function outer()
a = @submodel inner()
b = @submodel prefix = "sub" inner()
return a, b
end
@test outer()() isa Tuple{Float64,Float64}
vi = VarInfo(outer())
@test @varname(x) in keys(vi)
@test @varname(var"sub.x") in keys(vi)
end

@testset "logp is still accumulated properly" begin
@model inner_assume() = x ~ Normal()
@model inner_observe(x, y) = y ~ Normal(x)
@model function outer(b)
a = @submodel inner_assume()
@submodel inner_observe(a, b)
end
y_val = 1.0
model = outer(y_val)
@test model() == y_val

x_val = 1.5
vi = VarInfo(outer(y_val))
DynamicPPL.setindex!!(vi, x_val, @varname(x))
@test logprior(model, vi) logpdf(Normal(), x_val)
@test loglikelihood(model, vi) logpdf(Normal(x_val), y_val)
@test logjoint(model, vi)
logpdf(Normal(), x_val) + logpdf(Normal(x_val), y_val)
end
end
end
10 changes: 5 additions & 5 deletions test/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(
model, example_values...
)
logp_true = logprior(model, vi)
logprior_true = logprior(model, vi)

# Compute the pointwise loglikelihoods.
lls = pointwise_loglikelihoods(model, vi)
Expand All @@ -30,18 +30,18 @@
lps_prior = pointwise_prior_logdensities(model, vi)
@test :x DynamicPPL.getsym.(keys(lps_prior))
logp = sum(sum, values(lps_prior))
@test logp logp_true
@test logp logprior_true

# Compute both likelihood and logdensity of prior
# using the default DefaultContex
# using the default DefaultContext
lps = pointwise_logdensities(model, vi)
logp = sum(sum, values(lps))
@test logp (logp_true + loglikelihood_true)
@test logp (logprior_true + loglikelihood_true)

# Test that modifications of Setup are picked up
lps = pointwise_logdensities(model, vi, mod_ctx2)
logp = sum(sum, values(lps))
@test logp (logp_true + loglikelihood_true) * 1.2 * 1.4
@test logp (logprior_true + loglikelihood_true) * 1.2 * 1.4
end
end

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ include("test_util.jl")
include("serialization.jl")
include("pointwise_logdensities.jl")
include("lkj.jl")
include("deprecated.jl")
end

if GROUP == "All" || GROUP == "Group2"
Expand Down

0 comments on commit 681e472

Please sign in to comment.