Skip to content

Commit

Permalink
Merge pull request #78 from zeyus/master
Browse files Browse the repository at this point in the history
Fix num_warmups arg, added test for cmdstan cmdline args
  • Loading branch information
goedman authored Aug 8, 2024
2 parents e9e446f + 1d4aab6 commit 40f7f61
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
tmp
Manifest.toml
deps/data/bridgestan/bin
/.vscode
7 changes: 3 additions & 4 deletions src/stanrun/cmdline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ function cmdline(m::SampleModel, id; kwargs...)
cmd = ``
# Handle the model name field for unix and windows
cmd = `$(m.exec_path)`

if m.use_cpp_chains
cmd = :num_threads in keys(kwargs) ? `$cmd num_threads=$(m.num_threads)` : `$cmd`
cmd = `$cmd method=sample num_chains=$(m.num_cpp_chains)`
Expand All @@ -25,7 +24,7 @@ function cmdline(m::SampleModel, id; kwargs...)
end

cmd = :num_samples in keys(kwargs) ? `$cmd num_samples=$(m.num_samples)` : `$cmd`
cmd = :num_warmup in keys(kwargs) ? `$cmd num_warmup=$(m.num_warmups)` : `$cmd`
cmd = :num_warmups in keys(kwargs) ? `$cmd num_warmup=$(m.num_warmups)` : `$cmd`
cmd = :save_warmup in keys(kwargs) ? `$cmd save_warmup=$(m.save_warmup)` : `$cmd`
cmd = :save_warmup in keys(kwargs) ? `$cmd thin=$(m.thin)` : `$cmd`
cmd = `$cmd adapt engaged=$(m.engaged)`
Expand All @@ -38,8 +37,8 @@ function cmdline(m::SampleModel, id; kwargs...)
cmd = :window in keys(kwargs) ? `$cmd window=$(m.window)` : `$cmd`
cmd = :save_metric in keys(kwargs) ? `$cmd save_metric=$(m.save_metric)` : `$cmd`

# Algorithm section
cmd = :algorithm in keys(kwargs) ? `$cmd algorithm=$(string(m.algorithm))` : `$cmd`
# Algorithm section, algorithm can only be HMC
cmd = `$cmd algorithm=$(string(m.algorithm))`
if m.algorithm == :hmc
cmd = :engine in keys(kwargs) ? `$cmd engine=$(string(m.engine))` : `$cmd`
if m.engine == :nuts
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ if haskey(ENV, "CMDSTAN") || haskey(ENV, "JULIA_CMDSTAN_HOME")
"test_basic_runs/test_bernoulli_dict.jl",
"test_basic_runs/test_bernoulli_array_dict_1.jl",
"test_basic_runs/test_bernoulli_array_dict_2.jl",
"test_basic_runs/test_parse_interpolate.jl"
"test_basic_runs/test_parse_interpolate.jl",
"test_basic_runs/test_cmdstan_args.jl",
]

@testset "Bernoulli basic run tests" begin
Expand Down Expand Up @@ -242,4 +243,3 @@ if haskey(ENV, "CMDSTAN") || haskey(ENV, "JULIA_CMDSTAN_HOME")
else
println("\nCMDSTAN and JULIA_CMDSTAN_HOME not set. Skipping tests")
end

59 changes: 59 additions & 0 deletions test/test_basic_runs/test_cmdstan_args.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using StanSample, Test

ProjDir = @__DIR__
cd(ProjDir) # do

bernoulli_model = "
data {
int<lower=1> N;
array[N] int<lower=0,upper=1> y;
}
parameters {
real<lower=0,upper=1> theta;
}
model {
theta ~ beta(1,1);
y ~ bernoulli(theta);
}
";

sm = SampleModel("bernoulli", bernoulli_model)
observeddata = Dict("N" => 10, "y" => [0, 1, 0, 1, 0, 0, 0, 0, 0, 1])
rc = stan_sample(
sm;
data=observeddata,
num_samples=13,
num_warmups=17,
save_warmup=true,
num_chains=1,
sig_figs=2,
stepsize=0.7,
)

@test success(rc)
samples = read_samples(sm, :array)

shape = size(samples)
# number of samples, number of chains, number of parameters
@test shape == (30, 1, 1)

# read the log file
f = open(sm.log_file[1], "r")
# remove leading whitespace and chop off the "(default)" suffix
config = [chopsuffix(lstrip(x), r"\s+\(default\)$"i) for x in eachline(f) if length(x) > 0]
close(f)
# check that the config is as expected

required_entries = [
"method = sample",
"num_samples = 13",
"num_warmup = 17",
"save_warmup = true",
"num_chains = 1",
"sig_figs = 2",
"stepsize = 0.7",
]

for entry in required_entries
@test entry in config
end

0 comments on commit 40f7f61

Please sign in to comment.