Skip to content

Commit

Permalink
Fix `discard_initial`, and add support for `discard_initial` and `thinning` to iterator and transducer (#102)
Browse files Browse the repository at this point in the history

* Fix `discard_initial`, and add support for `discard_initial` and `thinning` to iterator and transducer

* Fix test errors on Julia < 1.6

* Only enable progress logging on Julia < 1.6

* Use different seed

* Update api.md

* Update api.md

* Update sample.jl

* Use `==` instead of `===`
  • Loading branch information
devmotion authored Jun 1, 2022
1 parent 650d9e1 commit 8d7f22f
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.1.0"
version = "4.1.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
6 changes: 4 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ AbstractMCMC.MCMCSerial

## Common keyword arguments

Common keyword arguments for regular and parallel sampling are:
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
- `chain_type` (default: `Any`): determines the type of the returned chain
- `callback` (default: `nothing`): if `callback !== nothing`, then
Expand All @@ -53,6 +52,9 @@ are:
- `discard_initial` (default: `0`): number of initial samples that are discarded
- `thinning` (default: `1`): factor by which to thin samples.

!!! info
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).

There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
Expand Down
4 changes: 2 additions & 2 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ function mcmcsample(
sample, state = step(rng, model, sampler; kwargs...)

# Discard initial samples.
for i in 1:(discard_initial - 1)
for i in 1:discard_initial
# Update the progress bar.
if progress && i >= next_update
ProgressLogging.@logprogress i / Ntotal
Expand Down Expand Up @@ -218,7 +218,7 @@ function mcmcsample(
sample, state = step(rng, model, sampler; kwargs...)

# Discard initial samples.
for _ in 2:discard_initial
for _ in 1:discard_initial
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
end
Expand Down
32 changes: 30 additions & 2 deletions src/stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,37 @@ struct Stepper{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K}
kwargs::K
end

# Initial sample.
#
# Starts the sampling algorithm and, if `discard_initial` is provided in the
# iterator's keyword arguments, advances past that many samples before yielding
# the first one. (The stale one-line definition from before this change has been
# removed; two definitions of the same method would silently shadow each other.)
function Base.iterate(stp::Stepper)
    # Unpack iterator.
    rng = stp.rng
    model = stp.model
    sampler = stp.sampler
    kwargs = stp.kwargs
    # `::Int` assertion keeps the loop bound type-stable.
    discard_initial = get(kwargs, :discard_initial, 0)::Int

    # Start sampling algorithm and discard initial samples if desired.
    sample, state = step(rng, model, sampler; kwargs...)
    for _ in 1:discard_initial
        sample, state = step(rng, model, sampler, state; kwargs...)
    end
    return sample, state
end

# Subsequent samples.
# Subsequent samples.
#
# Advances the chain by one yielded sample. If `thinning` is provided in the
# iterator's keyword arguments, `thinning - 1` intermediate samples are computed
# and dropped before the returned one. (A stale early-`return` line from the
# pre-thinning implementation has been removed; it made the thinning loop
# unreachable.)
function Base.iterate(stp::Stepper, state)
    # Unpack iterator.
    rng = stp.rng
    model = stp.model
    sampler = stp.sampler
    kwargs = stp.kwargs
    # `::Int` assertion keeps the loop bound type-stable.
    thinning = get(kwargs, :thinning, 1)::Int

    # Return next sample, possibly after thinning the chain if desired.
    for _ in 1:(thinning - 1)
        _, state = step(rng, model, sampler, state; kwargs...)
    end
    return step(rng, model, sampler, state; kwargs...)
end

Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite()
Expand Down
52 changes: 43 additions & 9 deletions src/transducer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,58 @@ function Sample(
return Sample(rng, model, sampler, kwargs)
end

# Initial sample.
# Initial sample.
#
# Starts the sampling algorithm for the `Sample` transducer and, if
# `discard_initial` is provided in the transducer's keyword arguments, discards
# that many samples before wrapping the first `(sample, state)` pair into the
# reducing-function state. (Interleaved leftovers of the old implementation —
# a duplicate `sampler` binding and a second `Transducers.wrap` call — have
# been removed.)
function Transducers.start(rf::Transducers.R_{<:Sample}, result)
    # Unpack transducer.
    td = Transducers.xform(rf)
    rng = td.rng
    model = td.model
    sampler = td.sampler
    kwargs = td.kwargs
    # `::Int` assertion keeps the loop bound type-stable.
    discard_initial = get(kwargs, :discard_initial, 0)::Int

    # Start sampling algorithm and discard initial samples if desired.
    sample, state = step(rng, model, sampler; kwargs...)
    for _ in 1:discard_initial
        sample, state = step(rng, model, sampler, state; kwargs...)
    end

    return Transducers.wrap(
        rf, (sample, state), Transducers.start(Transducers.inner(rf), result)
    )
end

# Subsequent samples.
# Subsequent samples.
#
# Emits the currently wrapped sample to the inner reducing function, then
# advances the chain — performing `thinning - 1` discarded steps first if
# `thinning` is provided in the transducer's keyword arguments. (Interleaved
# leftovers of the old `do`-block implementation have been removed.)
function Transducers.next(rf::Transducers.R_{<:Sample}, result, input)
    # Unpack transducer.
    td = Transducers.xform(rf)
    rng = td.rng
    model = td.model
    sampler = td.sampler
    kwargs = td.kwargs
    # `::Int` assertion keeps the loop bound type-stable.
    thinning = get(kwargs, :thinning, 1)::Int

    # `let` rebinds the captured variables so the closure below does not box
    # them (captured reassigned variables would otherwise hurt inference).
    let rng = rng,
        model = model,
        sampler = sampler,
        kwargs = kwargs,
        thinning = thinning,
        inner_rf = Transducers.inner(rf)

        Transducers.wrapping(rf, result) do (sample, state), iresult
            iresult2 = Transducers.next(inner_rf, iresult, sample)

            # Perform thinning if desired.
            for _ in 1:(thinning - 1)
                _, state = step(rng, model, sampler, state; kwargs...)
            end

            return step(rng, model, sampler, state; kwargs...), iresult2
        end
    end
end

# Finalize the `Sample` transducer: discard the wrapped `(sample, state)` pair
# and complete the inner reducing function. (A stale duplicate destructuring
# line from the old implementation has been removed.)
function Transducers.complete(rf::Transducers.R_{Sample}, result)
    _, inner_result = Transducers.unwrap(rf, result)
    return Transducers.complete(Transducers.inner(rf), inner_result)
end
101 changes: 75 additions & 26 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
@test chains isa Vector{<:MyChain}
@test length(chains) == 1000
@test all(x -> length(x.as) == length(x.bs) == N, chains)
@test all(ismissing(x.as[1]) for x in chains)

# test some statistical properties
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
Expand All @@ -147,9 +148,9 @@
# test reproducibility
Random.seed!(1234)
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain)

@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(ismissing(x.as[1]) for x in chains2)
@test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
@test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)

# Unexpected order of arguments.
str = "Number of chains (10) is greater than number of samples per chain (5)"
Expand Down Expand Up @@ -245,7 +246,7 @@

# Test output type and size.
@test chains isa Vector{<:MyChain}
@test all(c.as[1] === missing for c in chains)
@test all(ismissing(c.as[1]) for c in chains)
@test length(chains) == 1000
@test all(x -> length(x.as) == length(x.bs) == N, chains)

Expand All @@ -260,9 +261,9 @@
chains2 = sample(
MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain
)

@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(ismissing(c.as[1]) for c in chains2)
@test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
@test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)

# Unexpected order of arguments.
str = "Number of chains (10) is greater than number of samples per chain (5)"
Expand Down Expand Up @@ -330,7 +331,7 @@

# Test output type and size.
@test chains isa Vector{<:MyChain}
@test all(c.as[1] === missing for c in chains)
@test all(ismissing(c.as[1]) for c in chains)
@test length(chains) == 1000
@test all(x -> length(x.as) == length(x.bs) == N, chains)

Expand All @@ -343,9 +344,9 @@
# Test reproducibility.
Random.seed!(1234)
chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain)

@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(ismissing(c.as[1]) for c in chains2)
@test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
@test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)

# Unexpected order of arguments.
str = "Number of chains (10) is greater than number of samples per chain (5)"
Expand Down Expand Up @@ -415,6 +416,7 @@
progress=false,
chain_type=MyChain,
)
@test all(ismissing(c.as[1]) for c in chains_serial)

# Multi-threaded sampling
Random.seed!(1234)
Expand All @@ -427,12 +429,13 @@
progress=false,
chain_type=MyChain,
)
@test all(ismissing(c.as[1]) for c in chains_threads)
@test all(
c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads),
i in 1:N
c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads),
i in 2:N
)
@test all(
c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads),
c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads),
i in 1:N
)

Expand All @@ -447,12 +450,13 @@
progress=false,
chain_type=MyChain,
)
@test all(ismissing(c.as[1]) for c in chains_distributed)
@test all(
c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed),
i in 1:N
c1.as[i] == c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed),
i in 2:N
)
@test all(
c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed),
c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed),
i in 1:N
)
end
Expand All @@ -473,24 +477,41 @@
end

@testset "Discard initial samples" begin
    # Create a chain and discard initial samples.
    # (Stale pre-change test lines using `sleepy=true` and a duplicate `chain`
    # assignment have been removed.)
    Random.seed!(1234)
    N = 100
    discard_initial = 50
    chain = sample(MyModel(), MySampler(), N; discard_initial=discard_initial)
    @test length(chain) == N
    # The first retained sample is no longer the `missing` initial sample.
    @test !ismissing(chain[1].a)

    # Repeat sampling without discarding initial samples.
    # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
    # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
    Random.seed!(1234)
    ref_chain = sample(
        MyModel(), MySampler(), N + discard_initial; progress=VERSION < v"1.6"
    )
    @test all(chain[i].a == ref_chain[i + discard_initial].a for i in 1:N)
    @test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N)
end

@testset "Thin chain by a factor of `thinning`" begin
    # Run a thinned chain with `N` samples thinned by factor of `thinning`.
    # (Stale pre-change lines — a conflicting seed, a `sleepy=true` sampling
    # call, and an `===`-based reference comparison — have been removed.)
    Random.seed!(100)
    N = 100
    thinning = 3
    chain = sample(MyModel(), MySampler(), N; thinning=thinning)
    @test length(chain) == N
    # The initial sample is kept and its `a` is `missing`.
    @test ismissing(chain[1].a)

    # Repeat sampling without thinning.
    # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
    # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
    Random.seed!(100)
    ref_chain = sample(MyModel(), MySampler(), N * thinning; progress=VERSION < v"1.6")
    # Start at i = 2 for `a` since the first `a` is `missing` and `missing == missing`
    # is not `true`.
    @test all(chain[i].a == ref_chain[(i - 1) * thinning + 1].a for i in 2:N)
    @test all(chain[i].b == ref_chain[(i - 1) * thinning + 1].b for i in 1:N)
end

@testset "Sample without predetermined N" begin
Expand All @@ -501,16 +522,44 @@
@test abs(bmean) <= 0.001 || length(chain) == 10_000

# Discard initial samples.
chain = sample(MyModel(), MySampler(); discard_initial=50)
Random.seed!(1234)
discard_initial = 50
chain = sample(MyModel(), MySampler(); discard_initial=discard_initial)
bmean = mean(x.b for x in chain)
@test !ismissing(chain[1].a)
@test abs(bmean) <= 0.001 || length(chain) == 10_000

# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
Random.seed!(1234)
N = length(chain)
ref_chain = sample(
MyModel(),
MySampler(),
N;
discard_initial=discard_initial,
progress=VERSION < v"1.6",
)
@test all(chain[i].a == ref_chain[i].a for i in 1:N)
@test all(chain[i].b == ref_chain[i].b for i in 1:N)

# Thin chain by a factor of `thinning`.
chain = sample(MyModel(), MySampler(); thinning=3)
Random.seed!(1234)
thinning = 3
chain = sample(MyModel(), MySampler(); thinning=thinning)
bmean = mean(x.b for x in chain)
@test ismissing(chain[1].a)
@test abs(bmean) <= 0.001 || length(chain) == 10_000

# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
Random.seed!(1234)
N = length(chain)
ref_chain = sample(
MyModel(), MySampler(), N; thinning=thinning, progress=VERSION < v"1.6"
)
@test all(chain[i].a == ref_chain[i].a for i in 2:N)
@test all(chain[i].b == ref_chain[i].b for i in 1:N)
end

@testset "Sample vector of `NamedTuple`s" begin
Expand Down
Loading

2 comments on commit 8d7f22f

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/61469

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.1.1 -m "<description of version>" 8d7f22f5a047a16b6870ebb15c0090331db8dcaa
git push origin v4.1.1

Please sign in to comment.