Skip to content

Commit

Permalink
Document StanSample usage with ArviZ (#261)
Browse files Browse the repository at this point in the history
* Use recent cmdstan version in CI

* Document usage of Stan.jl's native conversion

* Remove extra line

* Switch tests to use StanSample

* Update docs project

* Increment patch number

* Loosen StanSample compat for tests

* Bump minimum StanSample version

* Update quickstart.jl

* Skip cmdstan test on windows

* Skip from_cmdstan test on older Julia versions

* Remove paren
sethaxen authored Jan 16, 2023
1 parent 9abdb06 commit e91e1b6
Showing 6 changed files with 46 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ on:

env:
PYTHON: "Conda" # use Julia's packaged Conda build for installing packages
CMDSTAN_VERSION: "2.25.0"
CMDSTAN_VERSION: "2.31.0"
CMDSTAN_PATH: "${{ GITHUB.WORKSPACE }}/.cmdstan/"

jobs:
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArviZ"
uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
authors = ["Seth Axen <[email protected]>"]
version = "0.8.0"
version = "0.8.1"

[deps]
ArviZExampleData = "2f96bb34-afd9-46ae-bcd0-9b2d4372fe3c"
4 changes: 2 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@
AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
ArviZ = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -15,13 +14,13 @@ PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
SampleChains = "754583d1-7fc4-4dab-93b5-5eaca5c9622e"
StanSample = "c1514b29-d3a0-5178-b312-660c88baa699"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AlgebraOfGraphics = "0.6.9"
ArviZ = "0.8"
CairoMakie = "0.8.9, 0.9, 0.10"
CmdStan = "6.0"
DataFrames = "1"
DimensionalData = "0.23, 0.24"
Distributions = "0.25"
@@ -32,4 +31,5 @@ PlutoUI = "0.7"
PyCall = "1.0"
PyPlot = "2.0"
SampleChains = "0.5"
StanSample = "7.0.1"
Turing = "0.21, 0.22, 0.23"
66 changes: 28 additions & 38 deletions docs/src/quickstart.jl
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@ using InteractiveUtils
using Pkg, InteractiveUtils

# ╔═╡ ac57d957-0bdd-457a-ac15-9a4f94f0c785
# Remove this cell to use release versions of dependencies
# Remove this cell to use release versions of dependencies
# hideall
let
docs_dir = dirname(@__DIR__)
@@ -20,7 +20,7 @@ let
end;

# ╔═╡ 467c2d13-6bfe-4feb-9626-fb14796168aa
using ArviZ, CmdStan, Distributions, LinearAlgebra, PyPlot, Random, Turing
using ArviZ, Distributions, LinearAlgebra, PyPlot, Random, StanSample, Turing

# ╔═╡ 56a39a90-0594-48f4-ba04-f7b612019cd1
using PlutoUI
@@ -30,7 +30,7 @@ md"""
# [ArviZ.jl Quickstart](#quickstart)
!!! note
This tutorial is adapted from [ArviZ's quickstart](https://python.arviz.org/en/latest/getting_started/Introduction.html).
"""

@@ -49,7 +49,7 @@ ArviZ.use_style("arviz-darkgrid")
md"""
## [Get started with plotting](#Get-started-with-plotting)
ArviZ.jl is designed to be used with libraries like [CmdStan](https://github.com/StanJulia/CmdStan.jl), [Turing.jl](https://turing.ml), and [Soss.jl](https://github.com/cscherrer/Soss.jl) but works fine with raw arrays.
ArviZ.jl is designed to be used with libraries like [Stan](https://github.com/StanJulia/Stan.jl), [Turing.jl](https://turing.ml), and [Soss.jl](https://github.com/cscherrer/Soss.jl) but works fine with raw arrays.
"""

# ╔═╡ efb3f0af-9fac-48d8-bbb2-2dd6ebd5e4f6
@@ -286,11 +286,11 @@ end

# ╔═╡ 98acc304-22e3-4e6b-a2f4-d22f6847145b
md"""
## [Plotting with CmdStan.jl outputs](#Plotting-with-CmdStan.jl-outputs)
## [Plotting with Stan.jl outputs](#Plotting-with-Stan.jl-outputs)
CmdStan.jl and StanSample.jl also default to producing `Chains` outputs, and we can easily plot these chains.
StanSample.jl comes with built-in support for producing `InferenceData` outputs.
Here is the same centered eight schools model:
Here is the same centered eight schools model in Stan:
"""

# ╔═╡ b46af168-1ce3-4058-a014-b66c645a6e0d
@@ -326,44 +326,36 @@ begin
"""

schools_data = Dict("J" => J, "y" => y, "sigma" => σ)
stan_chns = mktempdir() do path
stan_model = Stanmodel(;
model=schools_code,
name="schools",
nchains,
num_warmup=ndraws_warmup,
idata_stan = mktempdir() do path
stan_model = SampleModel("schools", schools_code, path)
_ = stan_sample(
stan_model;
data=schools_data,
num_chains=nchains,
num_warmups=ndraws_warmup,
num_samples=ndraws,
output_format=:mcmcchains,
random=CmdStan.Random(28983),
tmpdir=path,
seed=28983,
summary=false,
)
return StanSample.inferencedata(
stan_model;
posterior_predictive_var=:y_hat,
observed_data=(; y),
log_likelihood_var=:log_lik,
coords=(; school=schools),
dims=NamedTuple(
k => (:school,) for k in (:y, :sigma, :theta, :log_lik, :y_hat)
),
)
_, chns, _ = stan(stan_model, schools_data; summary=false)
return chns
end
end;
end

# ╔═╡ ab145e41-b230-4cad-bef5-f31e0e0770d4
begin
plot_density(stan_chns; var_names=(:mu, :tau))
plot_density(idata_stan; var_names=(:mu, :tau))
gcf()
end

# ╔═╡ ffc7730c-d861-48e8-b173-b03e0542f32b
md"""
Again, converting to `InferenceData`, we can get much richer labelling and mixing of data.
Note that we're using the same [`from_cmdstan`](https://julia.arviz.org/stable/reference/#ArviZ.from_cmdstan) function used by ArviZ to process cmdstan output files, but through the power of dispatch in Julia, if we pass a `Chains` object, it instead uses ArviZ.jl's overloads, which forward to `from_mcmcchains`.
"""

# ╔═╡ 020cbdc0-a0a2-4d20-838f-c99b541d5832
idata_stan = from_cmdstan(
stan_chns;
posterior_predictive=:y_hat,
observed_data=(; y),
log_likelihood=:log_lik,
coords=(; school=schools),
dims=NamedTuple(k => (:school,) for k in (:y, :sigma, :theta, :log_lik, :y_hat)),
)

# ╔═╡ e44b260c-9d2f-43f8-a64b-04245a0a5658
md"""Here is a plot showing where the Hamiltonian sampler had divergences:"""

@@ -430,8 +422,6 @@ with_terminal(versioninfo)
# ╟─98acc304-22e3-4e6b-a2f4-d22f6847145b
# ╠═b46af168-1ce3-4058-a014-b66c645a6e0d
# ╠═ab145e41-b230-4cad-bef5-f31e0e0770d4
# ╟─ffc7730c-d861-48e8-b173-b03e0542f32b
# ╠═020cbdc0-a0a2-4d20-838f-c99b541d5832
# ╟─e44b260c-9d2f-43f8-a64b-04245a0a5658
# ╠═5070bbbc-68d2-49b8-bd91-456dc0da4573
# ╟─ac2b4378-bd1c-4164-af05-d9a35b1bb08f
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[deps]
CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
@@ -10,11 +9,11 @@ PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SampleChains = "754583d1-7fc4-4dab-93b5-5eaca5c9622e"
SampleChainsDynamicHMC = "6d9fd711-e8b2-4778-9c70-c1dfb499d4c4"
StanSample = "c1514b29-d3a0-5178-b312-660c88baa699"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CmdStan = "5.2.3, 6.0"
DataFrames = "0.20, 0.21, 0.22, 1.0"
DimensionalData = "0.20, 0.21, 0.22, 0.23, 0.24"
MCMCChains = "0.3.15, 0.4, 1.0, 2.0, 3.0, 4.0, 5.0"
@@ -24,3 +23,4 @@ PyCall = "1.91.2"
PyPlot = "2.8.2"
SampleChains = "0.5"
SampleChainsDynamicHMC = "0.3"
StanSample = "6, 7"
26 changes: 12 additions & 14 deletions test/test_mcmcchains.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using Test
using ArviZ
using DimensionalData
using MCMCChains: MCMCChains
using CmdStan, OrderedCollections
using OrderedCollections, StanSample

const noncentered_schools_stan_model = """
data {
@@ -49,19 +52,14 @@ function makechains(nvars::Int, args...; kwargs...)
return makechains(names, args...; kwargs...)
end

function cmdstan_noncentered_schools(data, draws, chains; tmpdir=joinpath(pwd(), "tmp"))
function stan_noncentered_schools(data, draws, chains; tmpdir=mktempdir())
model_name = "school8"
stan_model = Stanmodel(;
name=model_name,
model=noncentered_schools_stan_model,
nchains=chains,
num_warmup=draws,
num_samples=draws,
output_format=:mcmcchains,
tmpdir,
stan_model = SampleModel(model_name, noncentered_schools_stan_model, tmpdir)
_ = stan_sample(
stan_model; data=data, num_chains=chains, num_samples=draws, summary=false
)
rc, chns, cnames = stan(stan_model, data; summary=false)
outfiles = ["$(tmpdir)/$(model_name)_samples_$(i).csv" for i in 1:chains]
chns = read_samples(stan_model, :mcmcchains; include_internals=true)
outfiles = ["$(stan_model.output_base)_chain_$(i).csv" for i in 1:chains]
return (model=stan_model, files=outfiles, chains=chns)
end

@@ -339,10 +337,10 @@ end
@test ArviZ.summary(chn) !== nothing
end

@testset "from_cmdstan" begin
Sys.iswindows() || VERSION < v"1.8" || @testset "from_cmdstan" begin
data = noncentered_schools_data()
mktempdir() do path
output = cmdstan_noncentered_schools(data, 500, 4; tmpdir=path)
output = stan_noncentered_schools(data, 500, 4; tmpdir=path)
posterior_predictive = prior_predictive = [:y_hat]
log_likelihood = :log_lik
coords = (school=1:8,)

2 comments on commit e91e1b6

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/75805

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.1 -m "<description of version>" e91e1b6c63fc41ed53af1ab3e11e1d8995642e6d
git push origin v0.8.1

Please sign in to comment.