Skip to content

Commit

Permalink
Primitive implementation for serialization (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 authored Dec 12, 2024
1 parent 5e5ac1d commit f51144c
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 5 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.7.5"
version = "0.8.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -22,6 +22,7 @@ MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down Expand Up @@ -67,6 +68,7 @@ MacroTools = "0.5"
MetaGraphsNext = "0.6, 0.7"
OrderedCollections = "1"
PDMats = "0.10, 0.11"
Serialization = "1.9.0"
SpecialFunctions = "2"
StaticArrays = "1.9"
Statistics = "1.9"
Expand Down
3 changes: 2 additions & 1 deletion src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using LogDensityProblems, LogDensityProblemsAD
using MacroTools
using OrderedCollections: OrderedDict
using Random
using Serialization: Serialization
using StaticArrays

import Base: ==, hash, Symbol, size
Expand Down Expand Up @@ -172,7 +173,7 @@ function compile(model_def::Expr, data::NamedTuple, initial_params::NamedTuple=N
values(eval_env),
),
)
return BUGSModel(g, nonmissing_eval_env, initial_params)
return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)
end

"""
Expand Down
37 changes: 34 additions & 3 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ end
The `BUGSModel` object is used for inference and represents the output of compilation. It implements the
[`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface.
"""
struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV} <:
AbstractBUGSModel
struct BUGSModel{
base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV,data_T
} <: AbstractBUGSModel
" Indicates whether the model parameters are in the transformed space. "
transformed::Bool

Expand All @@ -74,6 +75,10 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,

"If not `Nothing`, the model is a conditioned model; otherwise, it's the model returned by `compile`."
base_model::base_model_T

# for serialization, save the original model definition and data
model_def::Expr
data::data_T
end

function Base.show(io::IO, model::BUGSModel)
Expand Down Expand Up @@ -137,7 +142,9 @@ variables(model::BUGSModel) = collect(labels(model.g))
function BUGSModel(
g::BUGSGraph,
evaluation_env::NamedTuple,
initial_params::NamedTuple=NamedTuple();
model_def::Expr,
data::NamedTuple,
initial_params::NamedTuple=NamedTuple(),
is_transformed::Bool=true,
)
flattened_graph_node_data = FlattenedGraphNodeData(g)
Expand Down Expand Up @@ -199,6 +206,8 @@ function BUGSModel(
flattened_graph_node_data,
g,
nothing,
model_def,
data,
)
end

Expand All @@ -220,9 +229,31 @@ function BUGSModel(
FlattenedGraphNodeData(g, sorted_nodes),
g,
isnothing(model.base_model) ? model : model.base_model,
model.model_def,
model.data,
)
end

function Serialization.serialize(s::Serialization.AbstractSerializer, model::BUGSModel)
Serialization.writetag(s.io, Serialization.OBJECT_TAG)
Serialization.serialize(s, typeof(model))
Serialization.serialize(s, model.transformed)
Serialization.serialize(s, model.model_def)
Serialization.serialize(s, model.data)
Serialization.serialize(s, model.evaluation_env)
return nothing
end

function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{<:BUGSModel})
model_def = Serialization.deserialize(s)
data = Serialization.deserialize(s)
evaluation_env = Serialization.deserialize(s)
transformed = Serialization.deserialize(s)
# use evaluation_env as initialization to restore the values
model = compile(model_def, data, evaluation_env)
return settrans(model, transformed)
end

"""
initialize!(model::BUGSModel, initial_params::NamedTuple)
Expand Down
29 changes: 29 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
@testset "serialization" begin
(; model_def, data) = JuliaBUGS.BUGSExamples.rats
model = compile(model_def, data)
serialize("m.jls", model)
deserialized = deserialize("m.jls")
@testset "test values are correctly restored" begin
for vn in MetaGraphsNext.labels(model.g)
@test isequal(
get(model.evaluation_env, vn), get(deserialized.evaluation_env, vn)
)
end

@test model.transformed == deserialized.transformed
@test model.untransformed_param_length == deserialized.untransformed_param_length
@test model.transformed_param_length == deserialized.transformed_param_length
@test all(
model.untransformed_var_lengths[k] == deserialized.untransformed_var_lengths[k]
for k in keys(model.untransformed_var_lengths)
)
@test all(
model.transformed_var_lengths[k] == deserialized.transformed_var_lengths[k] for
k in keys(model.transformed_var_lengths)
)
@test Set(model.parameters) == Set(deserialized.parameters)
# skip testing g
@test model.model_def === deserialized.model_def
end
end

@testset "controlling sampling behavior for conditioned variables" begin
model_def = @bugs begin
x ~ Normal(0, 1)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using MacroTools
using MCMCChains
using Random
using ReverseDiff
using Serialization

AbstractMCMC.setprogress!(false)

Expand Down

2 comments on commit f51144c

@sunxd3
Copy link
Member Author

@sunxd3 sunxd3 commented on f51144c Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/121294

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.0 -m "<description of version>" f51144cfe8c7f33cc6bcd4ed7299d2b18c23f247
git push origin v0.8.0

Please sign in to comment.