diff --git a/.gitignore b/.gitignore index c74728d..cbc33f2 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ tmp Manifest.toml deps/data/bridgestan/bin +/.vscode diff --git a/src/stanrun/cmdline.jl b/src/stanrun/cmdline.jl index 47eab1b..3f56fe8 100644 --- a/src/stanrun/cmdline.jl +++ b/src/stanrun/cmdline.jl @@ -16,7 +16,6 @@ function cmdline(m::SampleModel, id; kwargs...) cmd = `` # Handle the model name field for unix and windows cmd = `$(m.exec_path)` - if m.use_cpp_chains cmd = :num_threads in keys(kwargs) ? `$cmd num_threads=$(m.num_threads)` : `$cmd` cmd = `$cmd method=sample num_chains=$(m.num_cpp_chains)` @@ -25,7 +24,7 @@ function cmdline(m::SampleModel, id; kwargs...) end cmd = :num_samples in keys(kwargs) ? `$cmd num_samples=$(m.num_samples)` : `$cmd` - cmd = :num_warmup in keys(kwargs) ? `$cmd num_warmup=$(m.num_warmups)` : `$cmd` + cmd = :num_warmups in keys(kwargs) ? `$cmd num_warmup=$(m.num_warmups)` : `$cmd` cmd = :save_warmup in keys(kwargs) ? `$cmd save_warmup=$(m.save_warmup)` : `$cmd` cmd = :save_warmup in keys(kwargs) ? `$cmd thin=$(m.thin)` : `$cmd` cmd = `$cmd adapt engaged=$(m.engaged)` @@ -38,8 +37,8 @@ function cmdline(m::SampleModel, id; kwargs...) cmd = :window in keys(kwargs) ? `$cmd window=$(m.window)` : `$cmd` cmd = :save_metric in keys(kwargs) ? `$cmd save_metric=$(m.save_metric)` : `$cmd` - # Algorithm section - cmd = :algorithm in keys(kwargs) ? `$cmd algorithm=$(string(m.algorithm))` : `$cmd` + # Algorithm section, algorithm can only be HMC + cmd = `$cmd algorithm=$(string(m.algorithm))` if m.algorithm == :hmc cmd = :engine in keys(kwargs) ? `$cmd engine=$(string(m.engine))` : `$cmd` if m.engine == :nuts diff --git a/test/runtests.jl b/test/runtests.jl index 5d2b2e3..ad04d4f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -118,7 +118,8 @@ if haskey(ENV, "CMDSTAN") || haskey(ENV, "JULIA_CMDSTAN_HOME") "test_basic_runs/test_bernoulli_dict.jl", "test_basic_runs/test_bernoulli_array_dict_1.jl", "test_basic_runs/test_bernoulli_array_dict_2.jl", - "test_basic_runs/test_parse_interpolate.jl" + "test_basic_runs/test_parse_interpolate.jl", + "test_basic_runs/test_cmdstan_args.jl", ] @testset "Bernoulli basic run tests" begin @@ -242,4 +243,3 @@ if haskey(ENV, "CMDSTAN") || haskey(ENV, "JULIA_CMDSTAN_HOME") else println("\nCMDSTAN and JULIA_CMDSTAN_HOME not set. Skipping tests") end - diff --git a/test/test_basic_runs/test_cmdstan_args.jl b/test/test_basic_runs/test_cmdstan_args.jl new file mode 100644 index 0000000..3113d57 --- /dev/null +++ b/test/test_basic_runs/test_cmdstan_args.jl @@ -0,0 +1,59 @@ +using StanSample, Test + +ProjDir = @__DIR__ +cd(ProjDir) # do + + bernoulli_model = " + data { + int N; + array[N] int y; + } + parameters { + real theta; + } + model { + theta ~ beta(1,1); + y ~ bernoulli(theta); + } + "; + + sm = SampleModel("bernoulli", bernoulli_model) + observeddata = Dict("N" => 10, "y" => [0, 1, 0, 1, 0, 0, 0, 0, 0, 1]) + rc = stan_sample( + sm; + data=observeddata, + num_samples=13, + num_warmups=17, + save_warmup=true, + num_chains=1, + sig_figs=2, + stepsize=0.7, + ) + + @test success(rc) + samples = read_samples(sm, :array) + + shape = size(samples) + # number of samples, number of chains, number of parameters + @test shape == (30, 1, 1) + + # read the log file + f = open(sm.log_file[1], "r") + # remove leading whitespace and chop off the "(default)" suffix + config = [chopsuffix(lstrip(x), r"\s+\(default\)$"i) for x in eachline(f) if length(x) > 0] + close(f) + # check that the config is as expected + + required_entries = [ + "method = sample", + "num_samples = 13", + "num_warmup = 17", + "save_warmup = true", + "num_chains = 1", + "sig_figs = 2", + "stepsize = 0.7", + ] + + for entry in required_entries + @test entry in config + end