From 938a69dfabb7089f52700ef01fa3a9f2d667b7d1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 16 Jan 2025 14:49:18 +0000 Subject: [PATCH] Restrict `values_as_in_model` API (#778) --- Project.toml | 2 +- src/values_as_in_model.jl | 41 ++++++++++++++------------------------- test/model.jl | 16 --------------- 3 files changed, 16 insertions(+), 43 deletions(-) diff --git a/Project.toml b/Project.toml index 2bf60214f..fb9a1c55f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.33.1" +version = "0.34.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 16556ee8c..ca8cc1cb3 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -19,9 +19,9 @@ wants to extract the realization of a model in a constrained space. # Fields $(TYPEDFIELDS) """ -struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext +struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext "values that are extracted from the model" - values::T + values::OrderedDict "whether to extract variables on the LHS of :=" include_colon_eq::Bool "child context" @@ -114,34 +114,32 @@ function dot_tilde_assume( end """ - values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext]) - values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext]) + values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) Get the values of `varinfo` as they would be seen in the model. -If no `varinfo` is provided, then this is effectively the same as -[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref). +More specifically, this method attempts to extract the realization _as seen in +the model_. For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a +realization that is compatible with `truncated(Normal(); lower=0)` -- i.e. one +where the value of `x[1]` is positive -- regardless of whether `varinfo` is +working in unconstrained space. -More specifically, this method attempts to extract the realization _as seen in the model_. -For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible -with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained -space. - -Hence this method is a "safe" way of obtaining realizations in constrained space at the cost -of additional model evaluations. +Hence this method is a "safe" way of obtaining realizations in constrained +space at the cost of additional model evaluations. # Arguments - `model::Model`: model to extract realizations from. - `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. - `varinfo::AbstractVarInfo`: variable information to use for the extraction. -- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context` - will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`. +- `context::AbstractContext`: base context to use for the extraction. Defaults + to `DynamicPPL.DefaultContext()`. # Examples ## When `VarInfo` fails -The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables. +The following demonstrates a common pitfall when working with [`VarInfo`](@ref) +and constrained variables. ```jldoctest julia> using Distributions, StableRNGs @@ -191,19 +189,10 @@ true function values_as_in_model( model::Model, include_colon_eq::Bool, - varinfo::AbstractVarInfo=VarInfo(), + varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext(), ) context = ValuesAsInModelContext(include_colon_eq, context) evaluate!!(model, varinfo, context) return context.values end -function values_as_in_model( - rng::Random.AbstractRNG, - model::Model, - include_colon_eq::Bool, - varinfo::AbstractVarInfo=VarInfo(), - context::AbstractContext=DefaultContext(), -) - return values_as_in_model(model, true, varinfo, SamplingContext(rng, context)) -end diff --git a/test/model.jl b/test/model.jl index eb8d6a932..45c770cc4 100644 --- a/test/model.jl +++ b/test/model.jl @@ -429,22 +429,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end - - @testset "check that sampling obeys rng if passed" begin - @model function f() - x ~ Normal(0) - return y ~ Normal(x) - end - model = f() - # Call values_as_in_model with the rng - values = values_as_in_model(Random.Xoshiro(43), model, false) - # Check that they match the values that would be used if vi was seeded - # with that seed instead - expected_vi = VarInfo(Random.Xoshiro(43), model) - for vn in keys(values) - @test values[vn] == expected_vi[vn] - end - end end @testset "Erroneous model call" begin