From 635b53e8e2e42a7af90e42071e7ce45eb108f23f Mon Sep 17 00:00:00 2001 From: Rob J Goedman Date: Wed, 11 Jan 2023 09:59:51 -0700 Subject: [PATCH] Rel 7.0.1 - Set types for non-float columns of sample_stats in :namedtuple, :namedtuples, :dataframe, :dataframes in read_samples() --- Project.toml | 2 +- src/utils/dataframes.jl | 18 +++++ src/utils/namedtuples.jl | 8 ++ test/runtests.jl | 80 ++++++++++--------- .../test_bernoulli_keyedarray_01.jl | 4 +- test/test_inferencedata/test_inferencedata.jl | 3 + 6 files changed, 73 insertions(+), 42 deletions(-) diff --git a/Project.toml b/Project.toml index 7046af9..2481c1a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StanSample" uuid = "c1514b29-d3a0-5178-b312-660c88baa699" authors = ["Rob J Goedman "] -version = "7.0.0" +version = "7.0.1" [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" diff --git a/src/utils/dataframes.jl b/src/utils/dataframes.jl index 0b8d6a9..9166706 100644 --- a/src/utils/dataframes.jl +++ b/src/utils/dataframes.jl @@ -18,6 +18,15 @@ function convert_a3d(a3d_array, cnames, ::Val{:dataframe}) for j in 2:size(a3d_array, 3) df = vcat(df, DataFrame(a3d_array[:, :, j], Symbol.(cnames))) end + + for name in names(df) + if name in ["treedepth__", "n_leapfrog__"] + df[!, name] = Int.(df[:, name]) + elseif name == "divergent__" + df[!, name] = Bool.(df[:, name]) + end + end + df end @@ -35,6 +44,15 @@ function convert_a3d(a3d_array, cnames, ::Val{:dataframes}) dfa = Vector{DataFrame}(undef, size(a3d_array, 3)) for j in 1:size(a3d_array, 3) dfa[j] = DataFrame(a3d_array[:, :, j], Symbol.(cnames)) + + for name in names(dfa[j]) + if name in ["treedepth__", "n_leapfrog__"] + dfa[j][!, name] = Int.(dfa[j][:, name]) + elseif name == "divergent__" + dfa[j][!, name] = Bool.(dfa[j][:, name]) + end + end + end dfa diff --git a/src/utils/namedtuples.jl b/src/utils/namedtuples.jl index 09b4e9b..70533c6 100644 --- a/src/utils/namedtuples.jl +++ b/src/utils/namedtuples.jl @@ -42,6 +42,14 @@ function extract(chns::Array{Float64,3}, cnames::Vector{String}; permute_dims=fa end end + for name in keys(ex_dict) + if name in [:treedepth__, :n_leapfrog__] + ex_dict[name] = convert(Matrix{Int}, ex_dict[name]) + elseif name == :divergent__ + ex_dict[name] = convert(Matrix{Bool}, ex_dict[name]) + end + end + return (;ex_dict...) end diff --git a/test/runtests.jl b/test/runtests.jl index eb7dc71..2671118 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,60 +3,62 @@ using StanSample, Test import CompatHelperLocal as CHL CHL.@check() -if haskey(ENV, "JULIA_CMDSTAN_HOME") || haskey(ENV, "CMDSTAN") +if haskey(ENV, "CMDSTAN") || haskey(ENV, "JULIA_CMDSTAN_HOME") TestDir = @__DIR__ tmpdir = mktempdir() - test_bernoulli = [ - "test_keyedarray/test_bernoulli_keyedarray_01.jl", - "test_keyedarray/test_keyedarray.jl", - ] - - @testset "Bernoulli array tests" begin - include(joinpath(TestDir, "test_bernoulli/test_bernoulli_keyedarray_01.jl")) - - if success(rc) + println("\nTesting test_bernoulli/test_bernoulli_keyedarray_01.jl") + include(joinpath(TestDir, "test_bernoulli/test_bernoulli_keyedarray_01.jl")) - sdf = read_summary(sm, :dataframe) - @test sdf[sdf.parameters .== :theta, :mean][1] ≈ 0.33 rtol=0.05 + if success(rc) - (samples, parameters) = read_samples(sm, :array; - return_parameters=true) - @test size(samples) == (1000, 1, 6) - @test length(parameters) == 1 + sdf = read_summary(sm, :dataframe) + sdf |> display - (samples, parameters) = read_samples(sm, :array; - return_parameters=true, include_internals=true) - @test size(samples) == (1000, 8, 6) - @test length(parameters) == 8 + @test sdf[sdf.parameters .== :theta, :mean][1] ≈ 0.33 rtol=0.05 - samples = read_samples(sm, :array; - include_internals=true) - @test size(samples) == (1000, 8, 6) + (samples, parameters) = read_samples(sm, :array; + return_parameters=true) + @test size(samples) == (1000, 1, 6) + @test length(parameters) == 1 - samples = read_samples(sm, :array) - @test size(samples) == (1000, 1, 6) - end + (samples, parameters) = read_samples(sm, :array; + return_parameters=true, include_internals=true) + @test size(samples) == (1000, 8, 6) + @test length(parameters) == 8 - include(joinpath(TestDir, "test_bernoulli/test_bernoulli_array_02.jl")) - if success(rc) - sdf = read_summary(sm) - @test sdf[sdf.parameters .== :theta, :mean][1] ≈ 0.33 rtol=0.05 + samples = read_samples(sm, :array; + include_internals=true) + @test size(samples) == (1000, 8, 6) - (samples, parameters) = read_samples(sm, :array; - return_parameters=true) - @test size(samples) == (250, 1, 4) - @test length(parameters) == 1 + samples = read_samples(sm, :array) + @test size(samples) == (1000, 1, 6) + + end + + println("\nTesting test_bernoulli/test_bernoulli_array_02.jl") + include(joinpath(TestDir, "test_bernoulli/test_bernoulli_array_02.jl")) + + if success(rc) + sdf = read_summary(sm) + @test sdf[sdf.parameters .== :theta, :mean][1] ≈ 0.33 rtol=0.05 + + (samples, parameters) = read_samples(sm, :array; + return_parameters=true) + @test size(samples) == (250, 1, 4) + @test length(parameters) == 1 + + samples = read_samples(sm, :array; + include_internals=true) + @test size(samples) == (250, 8, 4) + + end - samples = read_samples(sm, :array; - include_internals=true) - @test size(samples) == (250, 8, 4) - end end - test_bridgestan = [ + test_bridgestan = [ "test_bridgestan/test_bridgestan.jl", ] diff --git a/test/test_bernoulli/test_bernoulli_keyedarray_01.jl b/test/test_bernoulli/test_bernoulli_keyedarray_01.jl index b75bd03..17d92ee 100644 --- a/test/test_bernoulli/test_bernoulli_keyedarray_01.jl +++ b/test/test_bernoulli/test_bernoulli_keyedarray_01.jl @@ -1,6 +1,5 @@ ######### StanSample Bernoulli example ########### -using AxisKeys using StanSample bernoulli_model = " @@ -20,9 +19,10 @@ model { bernoulli_data = Dict("N" => 10, "y" => [0, 1, 0, 1, 0, 0, 0, 0, 0, 1]) sm = SampleModel("bernoulli", bernoulli_model); - rc = stan_sample(sm; data=bernoulli_data, num_chains=6); if success(rc) chns = read_samples(sm) end + +chns |> display diff --git a/test/test_inferencedata/test_inferencedata.jl b/test/test_inferencedata/test_inferencedata.jl index ccf2107..f2f0ea4 100644 --- a/test/test_inferencedata/test_inferencedata.jl +++ b/test/test_inferencedata/test_inferencedata.jl @@ -75,3 +75,6 @@ posterior_schools = DataFrame(idata.posterior) idata |> display +sample_stats_schools = DataFrame(idata.sample_stats) + +sample_stats_schools[1:10, :] |> display