Skip to content

Commit

Permalink
Revert "Merge pull request #126 from TuringLang/torfjelde/init-params…
Browse files Browse the repository at this point in the history
…-fix"

This reverts commit 084f809, reversing
changes made to caeade2.
  • Loading branch information
torfjelde committed Oct 24, 2023
1 parent 8919bfa commit 2be66bc
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 117 deletions.
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.6.0"
version = "4.4.2"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand All @@ -30,10 +30,9 @@ Transducers = "0.4.30"
julia = "1.6"

[extras]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["FillArrays", "IJulia", "Statistics", "Test"]
test = ["IJulia", "Statistics", "Test"]
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Common keyword arguments for regular and parallel sampling are:
There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`.

Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.

Expand Down
54 changes: 29 additions & 25 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ function mcmcsample(
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -364,10 +364,10 @@ function mcmcsample(
_sampler,
N;
progress=false,
init_params=if init_params === nothing
init_params=if _init_params === nothing
nothing
else
init_params[chainidx]
_init_params[chainidx]
end,
kwargs...,
)
Expand Down Expand Up @@ -410,9 +410,6 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -502,9 +499,6 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -538,21 +532,31 @@ end
tighten_eltype(x) = x
tighten_eltype(x::Vector{Any}) = map(identity, x)

@nospecialize check_initial_params(x, n) = throw(
ArgumentError(
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
),
)
"""
_first_or_nothing(x, n::Int)
check_initial_params(::Nothing, n) = nothing
function check_initial_params(x::AbstractArray, n)
if length(x) != n
throw(
ArgumentError(
"incorrect number of initial parameters (expected $n, received $(length(x))"
),
)
end
Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.
return nothing
If `x !== nothing`, then `x` has to contain at least `n` elements.
"""
function _first_or_nothing(x, n::Int)
y = _first(x, n)
length(y) == n || throw(
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
)
return y
end
_first_or_nothing(::Nothing, ::Int) = nothing

# `first(x, n::Int)` requires Julia 1.6
function _first(x, n::Int)
@static if VERSION >= v"1.6.0-DEV.431"
first(x, n)
else
if x isa AbstractVector
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
else
collect(Iterators.take(x, n))
end
end
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using IJulia
using LogDensityProblems
using LoggingExtras: TeeLogger, EarlyFilteredLogger
using TerminalLoggers: TerminalLogger
using FillArrays: FillArrays
using Transducers

using Distributed
Expand Down
105 changes: 18 additions & 87 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,17 @@
end

# initial parameters
nchains = 100
init_params = [(b=randn(), a=rand()) for _ in 1:nchains]
init_params = [(b=randn(), a=rand()) for _ in 1:100]
chains = sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
100;
progress=false,
init_params=init_params,
)
@test length(chains) == nchains
@test length(chains) == 100
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -185,36 +184,14 @@
MySampler(),
MCMCThreads(),
3,
nchains;
100;
progress=false,
init_params=FillArrays.Fill(init_params, nchains),
init_params=Iterators.repeated(init_params),
)
@test length(chains) == nchains
@test length(chains) == 100
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains - 1),
)
end

@testset "Multicore sampling" begin
Expand Down Expand Up @@ -297,18 +274,17 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
nchains = 100
init_params = [(a=randn(), b=rand()) for _ in 1:nchains]
init_params = [(a=randn(), b=rand()) for _ in 1:100]
chains = sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
100;
progress=false,
init_params=init_params,
)
@test length(chains) == nchains
@test length(chains) == 100
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -320,37 +296,15 @@
MySampler(),
MCMCDistributed(),
3,
nchains;
100;
progress=false,
init_params=FillArrays.Fill(init_params, nchains),
init_params=Iterators.repeated(init_params),
)
@test length(chains) == nchains
@test length(chains) == 100
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains - 1),
)

# Remove workers
rmprocs(pids...)
end
Expand Down Expand Up @@ -406,18 +360,17 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
nchains = 100
init_params = [(a=rand(), b=randn()) for _ in 1:nchains]
init_params = [(a=rand(), b=randn()) for _ in 1:100]
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
100;
progress=false,
init_params=init_params,
)
@test length(chains) == nchains
@test length(chains) == 100
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -429,36 +382,14 @@
MySampler(),
MCMCSerial(),
3,
nchains;
100;
progress=false,
init_params=FillArrays.Fill(init_params, nchains),
init_params=Iterators.repeated(init_params),
)
@test length(chains) == nchains
@test length(chains) == 100
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains - 1),
)
end

@testset "Ensemble sampling: Reproducibility" begin
Expand Down

0 comments on commit 2be66bc

Please sign in to comment.