From f51144cfe8c7f33cc6bcd4ed7299d2b18c23f247 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 12 Dec 2024 09:51:38 -0800 Subject: [PATCH] Primitive implementation for serialization (#258) --- Project.toml | 4 +++- src/JuliaBUGS.jl | 3 ++- src/model.jl | 37 ++++++++++++++++++++++++++++++++++--- test/model.jl | 29 +++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 69 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 4d338dc9..5acb519c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "JuliaBUGS" uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" -version = "0.7.5" +version = "0.8.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -22,6 +22,7 @@ MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -67,6 +68,7 @@ MacroTools = "0.5" MetaGraphsNext = "0.6, 0.7" OrderedCollections = "1" PDMats = "0.10, 0.11" +Serialization = "1.9.0" SpecialFunctions = "2" StaticArrays = "1.9" Statistics = "1.9" diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index d2add3de..b6b36501 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -12,6 +12,7 @@ using LogDensityProblems, LogDensityProblemsAD using MacroTools using OrderedCollections: OrderedDict using Random +using Serialization: Serialization using StaticArrays import Base: ==, hash, Symbol, size @@ -172,7 +173,7 @@ function compile(model_def::Expr, data::NamedTuple, initial_params::NamedTuple=N values(eval_env), ), ) - return BUGSModel(g, nonmissing_eval_env, initial_params) + return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params) end """ diff --git a/src/model.jl b/src/model.jl index a37bf34f..25f3b20d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -48,8 +48,9 @@ end The `BUGSModel` object is used for inference and represents the output of compilation. It implements the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface. """ -struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV} <: - AbstractBUGSModel +struct BUGSModel{ + base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV,data_T +} <: AbstractBUGSModel " Indicates whether the model parameters are in the transformed space. " transformed::Bool @@ -74,6 +75,10 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple, "If not `Nothing`, the model is a conditioned model; otherwise, it's the model returned by `compile`." base_model::base_model_T + + # for serialization, save the original model definition and data + model_def::Expr + data::data_T end function Base.show(io::IO, model::BUGSModel) @@ -137,7 +142,9 @@ variables(model::BUGSModel) = collect(labels(model.g)) function BUGSModel( g::BUGSGraph, evaluation_env::NamedTuple, - initial_params::NamedTuple=NamedTuple(); + model_def::Expr, + data::NamedTuple, + initial_params::NamedTuple=NamedTuple(), is_transformed::Bool=true, ) flattened_graph_node_data = FlattenedGraphNodeData(g) @@ -199,6 +206,8 @@ function BUGSModel( flattened_graph_node_data, g, nothing, + model_def, + data, ) end @@ -220,9 +229,31 @@ function BUGSModel( FlattenedGraphNodeData(g, sorted_nodes), g, isnothing(model.base_model) ? model : model.base_model, + model.model_def, + model.data, ) end +function Serialization.serialize(s::Serialization.AbstractSerializer, model::BUGSModel) + Serialization.writetag(s.io, Serialization.OBJECT_TAG) + Serialization.serialize(s, typeof(model)) + Serialization.serialize(s, model.transformed) + Serialization.serialize(s, model.model_def) + Serialization.serialize(s, model.data) + Serialization.serialize(s, model.evaluation_env) + return nothing +end + +function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{<:BUGSModel}) + model_def = Serialization.deserialize(s) + data = Serialization.deserialize(s) + evaluation_env = Serialization.deserialize(s) + transformed = Serialization.deserialize(s) + # use evaluation_env as initialization to restore the values + model = compile(model_def, data, evaluation_env) + return settrans(model, transformed) +end + """ initialize!(model::BUGSModel, initial_params::NamedTuple) diff --git a/test/model.jl b/test/model.jl index 9e44f01a..80767b26 100644 --- a/test/model.jl +++ b/test/model.jl @@ -1,3 +1,32 @@ +@testset "serialization" begin + (; model_def, data) = JuliaBUGS.BUGSExamples.rats + model = compile(model_def, data) + serialize("m.jls", model) + deserialized = deserialize("m.jls") + @testset "test values are correctly restored" begin + for vn in MetaGraphsNext.labels(model.g) + @test isequal( + get(model.evaluation_env, vn), get(deserialized.evaluation_env, vn) + ) + end + + @test model.transformed == deserialized.transformed + @test model.untransformed_param_length == deserialized.untransformed_param_length + @test model.transformed_param_length == deserialized.transformed_param_length + @test all( + model.untransformed_var_lengths[k] == deserialized.untransformed_var_lengths[k] + for k in keys(model.untransformed_var_lengths) + ) + @test all( + model.transformed_var_lengths[k] == deserialized.transformed_var_lengths[k] for + k in keys(model.transformed_var_lengths) + ) + @test Set(model.parameters) == Set(deserialized.parameters) + # skip testing g + @test model.model_def === deserialized.model_def + end +end + @testset "controlling sampling behavior for conditioned variables" begin model_def = @bugs begin x ~ Normal(0, 1) diff --git a/test/runtests.jl b/test/runtests.jl index fab36c47..a08b4260 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,6 +21,7 @@ using MacroTools using MCMCChains using Random using ReverseDiff +using Serialization AbstractMCMC.setprogress!(false)