Skip to content

Commit

Permalink
Rel 7.0.1 - Set types for non-float columns of sample_stats in :named…
Browse files Browse the repository at this point in the history
…tuple, :namedtuples, :dataframe, :dataframes in read_samples()
  • Loading branch information
goedman committed Jan 11, 2023
1 parent e3ecd65 commit 635b53e
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StanSample"
uuid = "c1514b29-d3a0-5178-b312-660c88baa699"
authors = ["Rob J Goedman <[email protected]>"]
version = "7.0.0"
version = "7.0.1"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Expand Down
18 changes: 18 additions & 0 deletions src/utils/dataframes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ function convert_a3d(a3d_array, cnames, ::Val{:dataframe})
for j in 2:size(a3d_array, 3)
df = vcat(df, DataFrame(a3d_array[:, :, j], Symbol.(cnames)))
end

for name in names(df)
if name in ["treedepth__", "n_leapfrog__"]
df[!, name] = Int.(df[:, name])
elseif name == "divergent__"
df[!, name] = Bool.(df[:, name])
end
end

df
end

Expand All @@ -35,6 +44,15 @@ function convert_a3d(a3d_array, cnames, ::Val{:dataframes})
dfa = Vector{DataFrame}(undef, size(a3d_array, 3))
for j in 1:size(a3d_array, 3)
dfa[j] = DataFrame(a3d_array[:, :, j], Symbol.(cnames))

for name in names(dfa[j])
if name in ["treedepth__", "n_leapfrog__"]
dfa[j][!, name] = Int.(dfa[j][:, name])
elseif name == "divergent__"
dfa[j][!, name] = Bool.(dfa[j][:, name])
end
end

end

dfa
Expand Down
8 changes: 8 additions & 0 deletions src/utils/namedtuples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ function extract(chns::Array{Float64,3}, cnames::Vector{String}; permute_dims=fa
end
end

for name in keys(ex_dict)
if name in [:treedepth__, :n_leapfrog__]
ex_dict[name] = convert(Matrix{Int}, ex_dict[name])
elseif name == :divergent__
ex_dict[name] = convert(Matrix{Bool}, ex_dict[name])
end
end

return (;ex_dict...)
end

Expand Down
80 changes: 41 additions & 39 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,62 @@ using StanSample, Test
import CompatHelperLocal as CHL
CHL.@check()

if haskey(ENV, "JULIA_CMDSTAN_HOME") || haskey(ENV, "CMDSTAN")
if haskey(ENV, "CMDSTAN") || haskey(ENV, "JULIA_CMDSTAN_HOME")

TestDir = @__DIR__
tmpdir = mktempdir()

test_bernoulli = [
"test_keyedarray/test_bernoulli_keyedarray_01.jl",
"test_keyedarray/test_keyedarray.jl",
]


@testset "Bernoulli array tests" begin
include(joinpath(TestDir, "test_bernoulli/test_bernoulli_keyedarray_01.jl"))

if success(rc)
println("\nTesting test_bernoulli/test_bernoulli_keyedarray_01.jl")
include(joinpath(TestDir, "test_bernoulli/test_bernoulli_keyedarray_01.jl"))

sdf = read_summary(sm, :dataframe)
@test sdf[sdf.parameters .== :theta, :mean][1] 0.33 rtol=0.05
if success(rc)

(samples, parameters) = read_samples(sm, :array;
return_parameters=true)
@test size(samples) == (1000, 1, 6)
@test length(parameters) == 1
sdf = read_summary(sm, :dataframe)
sdf |> display

(samples, parameters) = read_samples(sm, :array;
return_parameters=true, include_internals=true)
@test size(samples) == (1000, 8, 6)
@test length(parameters) == 8
@test sdf[sdf.parameters .== :theta, :mean][1] 0.33 rtol=0.05

samples = read_samples(sm, :array;
include_internals=true)
@test size(samples) == (1000, 8, 6)
(samples, parameters) = read_samples(sm, :array;
return_parameters=true)
@test size(samples) == (1000, 1, 6)
@test length(parameters) == 1

samples = read_samples(sm, :array)
@test size(samples) == (1000, 1, 6)
end
(samples, parameters) = read_samples(sm, :array;
return_parameters=true, include_internals=true)
@test size(samples) == (1000, 8, 6)
@test length(parameters) == 8

include(joinpath(TestDir, "test_bernoulli/test_bernoulli_array_02.jl"))
if success(rc)
sdf = read_summary(sm)
@test sdf[sdf.parameters .== :theta, :mean][1] 0.33 rtol=0.05
samples = read_samples(sm, :array;
include_internals=true)
@test size(samples) == (1000, 8, 6)

(samples, parameters) = read_samples(sm, :array;
return_parameters=true)
@test size(samples) == (250, 1, 4)
@test length(parameters) == 1
samples = read_samples(sm, :array)
@test size(samples) == (1000, 1, 6)

end

println("\nTesting test_bernoulli/test_bernoulli_array_02.jl")
include(joinpath(TestDir, "test_bernoulli/test_bernoulli_array_02.jl"))

if success(rc)
sdf = read_summary(sm)
@test sdf[sdf.parameters .== :theta, :mean][1] 0.33 rtol=0.05

(samples, parameters) = read_samples(sm, :array;
return_parameters=true)
@test size(samples) == (250, 1, 4)
@test length(parameters) == 1

samples = read_samples(sm, :array;
include_internals=true)
@test size(samples) == (250, 8, 4)

end

samples = read_samples(sm, :array;
include_internals=true)
@test size(samples) == (250, 8, 4)
end
end

test_bridgestan = [
test_bridgestan = [
"test_bridgestan/test_bridgestan.jl",
]

Expand Down
4 changes: 2 additions & 2 deletions test/test_bernoulli/test_bernoulli_keyedarray_01.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
######### StanSample Bernoulli example ###########

using AxisKeys
using StanSample

bernoulli_model = "
Expand All @@ -20,9 +19,10 @@ model {
bernoulli_data = Dict("N" => 10, "y" => [0, 1, 0, 1, 0, 0, 0, 0, 0, 1])

sm = SampleModel("bernoulli", bernoulli_model);

rc = stan_sample(sm; data=bernoulli_data, num_chains=6);

if success(rc)
chns = read_samples(sm)
end

chns |> display
3 changes: 3 additions & 0 deletions test/test_inferencedata/test_inferencedata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ posterior_schools = DataFrame(idata.posterior)

idata |> display

sample_stats_schools = DataFrame(idata.sample_stats)

sample_stats_schools[1:10, :] |> display

0 comments on commit 635b53e

Please sign in to comment.