diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..1e72b507 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style="blue" diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml new file mode 100644 index 00000000..ec14da16 --- /dev/null +++ b/.github/workflows/Format.yml @@ -0,0 +1,31 @@ +name: Format + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - name: Format code + run: | + using Pkg + Pkg.add(; name="JuliaFormatter", uuid="98e50ef6-434e-11e9-1051-2b60c6c9e899") + using JuliaFormatter + format("."; verbose=true) + shell: julia --color=yes {0} + - uses: reviewdog/action-suggester@v1 + if: github.event_name == 'pull_request' + with: + tool_name: JuliaFormatter + fail_on_error: true diff --git a/README.md b/README.md index a2d40c34..ee186269 100644 --- a/README.md +++ b/README.md @@ -8,3 +8,4 @@ Abstract types and interfaces for Markov chain Monte Carlo methods. [![IntegrationTest](https://github.com/TuringLang/AbstractMCMC.jl/workflows/IntegrationTest/badge.svg?branch=master)](https://github.com/TuringLang/AbstractMCMC.jl/actions?query=workflow%3AIntegrationTest+branch%3Amaster) [![Codecov](https://codecov.io/gh/TuringLang/AbstractMCMC.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AbstractMCMC.jl) [![Coveralls](https://coveralls.io/repos/github/TuringLang/AbstractMCMC.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/AbstractMCMC.jl?branch=master) +[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) diff --git a/docs/make.jl b/docs/make.jl index 67978d1e..66d7619c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,26 +2,15 @@ using AbstractMCMC using Documenter using Random -DocMeta.setdocmeta!( - AbstractMCMC, - :DocTestSetup, - :(using AbstractMCMC); - recursive=true, -) +DocMeta.setdocmeta!(AbstractMCMC, :DocTestSetup, :(using AbstractMCMC); recursive=true) makedocs(; sitename="AbstractMCMC", format=Documenter.HTML(), modules=[AbstractMCMC], - pages=[ - "Home" => "index.md", - "api.md", - "design.md", - ], + pages=["Home" => "index.md", "api.md", "design.md"], strict=true, checkdocs=:exports, ) -deploydocs(; - repo="github.com/TuringLang/AbstractMCMC.jl.git", push_preview=true -) +deploydocs(; repo="github.com/TuringLang/AbstractMCMC.jl.git", push_preview=true) diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index ef23cb51..686924a8 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -1,16 +1,16 @@ module AbstractMCMC -import BangBang -import ConsoleProgressMonitor -import LoggingExtras -import ProgressLogging -import StatsBase -import TerminalLoggers -import Transducers - -import Distributed -import Logging -import Random +using BangBang: BangBang +using ConsoleProgressMonitor: ConsoleProgressMonitor +using LoggingExtras: LoggingExtras +using ProgressLogging: ProgressLogging +using StatsBase: StatsBase +using TerminalLoggers: TerminalLoggers +using Transducers: Transducers + +using Distributed: Distributed +using Logging: Logging +using Random: Random # Reexport sample using StatsBase: sample @@ -71,7 +71,6 @@ processes. """ struct MCMCDistributed <: AbstractMCMCEnsemble end - """ MCMCSerial diff --git a/src/deprecations.jl b/src/deprecations.jl index 1cc93d12..128f16d1 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,2 +1,2 @@ # Deprecate the old name AbstractMCMCParallel in favor of AbstractMCMCEnsemble -Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false \ No newline at end of file +Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false diff --git a/src/interface.jl b/src/interface.jl index 7b3daefb..eaecb492 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -30,24 +30,14 @@ be specified with the `chain_type` argument. By default, this method returns `samples`. """ function bundle_samples( - samples, - ::AbstractModel, - ::AbstractSampler, - ::Any, - ::Type; - kwargs... + samples, ::AbstractModel, ::AbstractSampler, ::Any, ::Type; kwargs... ) return samples end function bundle_samples( - samples::Vector, - ::AbstractModel, - ::AbstractSampler, - ::Any, - ::Type{Vector{T}}; - kwargs... -) where T + samples::Vector, ::AbstractModel, ::AbstractSampler, ::Any, ::Type{Vector{T}}; kwargs... +) where {T} return map(samples) do sample convert(T, sample) end @@ -74,24 +64,13 @@ sample is `sample`. The method can be called with and without a predefined number `N` of samples. """ -function samples( - sample, - ::AbstractModel, - ::AbstractSampler, - N::Integer; - kwargs... -) +function samples(sample, ::AbstractModel, ::AbstractSampler, N::Integer; kwargs...) ts = Vector{typeof(sample)}(undef, 0) sizehint!(ts, N) return ts end -function samples( - sample, - ::AbstractModel, - ::AbstractSampler; - kwargs... -) +function samples(sample, ::AbstractModel, ::AbstractSampler; kwargs...) return Vector{typeof(sample)}(undef, 0) end @@ -113,7 +92,7 @@ function save!!( ::AbstractModel, ::AbstractSampler, N::Integer; - kwargs... + kwargs..., ) s = BangBang.push!!(samples, sample) s !== samples && sizehint!(s, N) @@ -121,27 +100,15 @@ function save!!( end function save!!( - samples, - sample, - iteration::Integer, - ::AbstractModel, - ::AbstractSampler; - kwargs... + samples, sample, iteration::Integer, ::AbstractModel, ::AbstractSampler; kwargs... ) return BangBang.push!!(samples, sample) end # Deprecations Base.@deprecate transitions( - transition, - model::AbstractModel, - sampler::AbstractSampler, - N::Integer; - kwargs... + transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs... ) samples(transition, model, sampler, N; kwargs...) false Base.@deprecate transitions( - transition, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + transition, model::AbstractModel, sampler::AbstractSampler; kwargs... ) samples(transition, model, sampler; kwargs...) false diff --git a/src/logging.jl b/src/logging.jl index a550c532..04c41187 100644 --- a/src/logging.jl +++ b/src/logging.jl @@ -2,19 +2,21 @@ # and add a custom progress logger if the current logger does not seem to be able to handle # progress logs macro ifwithprogresslogger(progress, exprs...) - return quote - if $progress - if $hasprogresslevel($Logging.current_logger()) - $ProgressLogging.@withprogress $(exprs...) - else - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + return esc( + quote + if $progress + if $hasprogresslevel($Logging.current_logger()) $ProgressLogging.@withprogress $(exprs...) + else + $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do + $ProgressLogging.@withprogress $(exprs...) + end end + else + $(exprs[end]) end - else - $(exprs[end]) - end - end |> esc + end, + ) end # improved checks? @@ -31,13 +33,14 @@ function with_progresslogger(f, _module, logger) log._module !== _module || log.level != ProgressLogging.ProgressLevel end - Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) + return Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2)) end function progresslogger() # detect if code is running under IJulia since TerminalLogger does not work with IJulia # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia - if (Sys.iswindows() && VERSION < v"1.5.3") || (isdefined(Main, :IJulia) && Main.IJulia.inited) + if (Sys.iswindows() && VERSION < v"1.5.3") || + (isdefined(Main, :IJulia) && Main.IJulia.inited) return ConsoleProgressMonitor.ProgressLogger() else return TerminalLoggers.TerminalLogger() diff --git a/src/sample.jl b/src/sample.jl index e1ff1be7..b6fad3fe 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -12,12 +12,7 @@ function setprogress!(progress::Bool) return progress end -function StatsBase.sample( - model::AbstractModel, - sampler::AbstractSampler, - arg; - kwargs... -) +function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler, arg; kwargs...) return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, arg; kwargs...) end @@ -31,7 +26,7 @@ function StatsBase.sample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - kwargs... + kwargs..., ) return mcmcsample(rng, model, sampler, N; kwargs...) end @@ -54,7 +49,7 @@ function StatsBase.sample( model::AbstractModel, sampler::AbstractSampler, isdone; - kwargs... + kwargs..., ) return mcmcsample(rng, model, sampler, isdone; kwargs...) end @@ -65,10 +60,11 @@ function StatsBase.sample( parallel::AbstractMCMCEnsemble, N::Integer, nchains::Integer; - kwargs... + kwargs..., ) - return StatsBase.sample(Random.GLOBAL_RNG, model, sampler, parallel, N, nchains; - kwargs...) + return StatsBase.sample( + Random.GLOBAL_RNG, model, sampler, parallel, N, nchains; kwargs... + ) end """ @@ -84,7 +80,7 @@ function StatsBase.sample( parallel::AbstractMCMCEnsemble, N::Integer, nchains::Integer; - kwargs... + kwargs..., ) return mcmcsample(rng, model, sampler, parallel, N, nchains; kwargs...) end @@ -96,13 +92,13 @@ function mcmcsample( model::AbstractModel, sampler::AbstractSampler, N::Integer; - progress = PROGRESS[], - progressname = "Sampling", - callback = nothing, - discard_initial = 0, - thinning = 1, + progress=PROGRESS[], + progressname="Sampling", + callback=nothing, + discard_initial=0, + thinning=1, chain_type::Type=Any, - kwargs... + kwargs..., ) # Check the number of requested samples. N > 0 || error("the number of samples must be ≥ 1") @@ -112,7 +108,7 @@ function mcmcsample( start = time() local state - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Determine threshold values for progress logging # (one update per 0.5% of progress) if progress @@ -127,7 +123,7 @@ function mcmcsample( for i in 1:(discard_initial - 1) # Update the progress bar. if progress && i >= next_update - ProgressLogging.@logprogress i/Ntotal + ProgressLogging.@logprogress i / Ntotal next_update = i + threshold end @@ -167,7 +163,8 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler, N; kwargs...) @@ -186,15 +183,15 @@ function mcmcsample( stats = SamplingStats(start, stop, duration) return bundle_samples( - samples, - model, + samples, + model, sampler, state, chain_type; stats=stats, discard_initial=discard_initial, thinning=thinning, - kwargs... + kwargs..., ) end @@ -204,19 +201,19 @@ function mcmcsample( sampler::AbstractSampler, isdone; chain_type::Type=Any, - progress = PROGRESS[], - progressname = "Convergence sampling", - callback = nothing, - discard_initial = 0, - thinning = 1, - kwargs... + progress=PROGRESS[], + progressname="Convergence sampling", + callback=nothing, + discard_initial=0, + thinning=1, + kwargs..., ) # Start the timer start = time() local state - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. sample, state = step(rng, model, sampler; kwargs...) @@ -247,7 +244,8 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler; kwargs...) @@ -264,15 +262,15 @@ function mcmcsample( # Wrap the samples up. return bundle_samples( - samples, + samples, model, - sampler, - state, - chain_type; + sampler, + state, + chain_type; stats=stats, discard_initial=discard_initial, thinning=thinning, - kwargs... + kwargs..., ) end @@ -283,9 +281,9 @@ function mcmcsample( ::MCMCThreads, N::Integer, nchains::Integer; - progress = PROGRESS[], - progressname = "Sampling ($(min(nchains, Threads.nthreads())) threads)", - kwargs... + progress=PROGRESS[], + progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", + kwargs..., ) # Check if actually multiple threads are used. if Threads.nthreads() == 1 @@ -311,7 +309,7 @@ function mcmcsample( # Set up a chains vector. chains = Vector{Any}(undef, nchains) - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Create a channel for progress logging. if progress channel = Channel{Bool}(length(interval)) @@ -330,7 +328,7 @@ function mcmcsample( while take!(channel) progresschains += 1 if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains/nchains + ProgressLogging.@logprogress progresschains / nchains nextprogresschains = progresschains + threshold end end @@ -339,7 +337,8 @@ function mcmcsample( Distributed.@async begin try - Distributed.@sync for (i, _rng, _model, _sampler) in zip(1:nchunks, rngs, models, samplers) + Distributed.@sync for (i, _rng, _model, _sampler) in + zip(1:nchunks, rngs, models, samplers) chainidxs = if i == nchunks ((i - 1) * chunksize + 1):nchains else @@ -350,8 +349,9 @@ function mcmcsample( Random.seed!(_rng, seeds[chainidx]) # Sample a chain and save it to the vector. - chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N; - progress = false, kwargs...) + chains[chainidx] = StatsBase.sample( + _rng, _model, _sampler, N; progress=false, kwargs... + ) # Update the progress bar. progress && put!(channel, true) @@ -376,9 +376,9 @@ function mcmcsample( ::MCMCDistributed, N::Integer, nchains::Integer; - progress = PROGRESS[], - progressname = "Sampling ($(Distributed.nworkers()) processes)", - kwargs... + progress=PROGRESS[], + progressname="Sampling ($(Distributed.nworkers()) processes)", + kwargs..., ) # Check if actually multiple processes are used. if Distributed.nworkers() == 1 @@ -397,7 +397,7 @@ function mcmcsample( pool = Distributed.CachingPool(Distributed.workers()) local chains - @ifwithprogresslogger progress name=progressname begin + @ifwithprogresslogger progress name = progressname begin # Create a channel for progress logging. if progress channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers())) @@ -416,7 +416,7 @@ function mcmcsample( while take!(channel) progresschains += 1 if progresschains >= nextprogresschains - ProgressLogging.@logprogress progresschains/nchains + ProgressLogging.@logprogress progresschains / nchains nextprogresschains = progresschains + threshold end end @@ -430,8 +430,9 @@ function mcmcsample( Random.seed!(rng, seed) # Sample a chain. - chain = StatsBase.sample(rng, model, sampler, N; - progress = false, kwargs...) + chain = StatsBase.sample( + rng, model, sampler, N; progress=false, kwargs... + ) # Update the progress bar. progress && put!(channel, true) @@ -458,8 +459,8 @@ function mcmcsample( ::MCMCSerial, N::Integer, nchains::Integer; - progressname = "Sampling", - kwargs... + progressname="Sampling", + kwargs..., ) # Check if the number of chains is larger than the number of samples if nchains > N @@ -473,8 +474,11 @@ function mcmcsample( chains = map(enumerate(seeds)) do (i, seed) Random.seed!(rng, seed) return StatsBase.sample( - rng, model, sampler, N; - progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"), + rng, + model, + sampler, + N; + progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), kwargs..., ) end diff --git a/src/samplingstats.jl b/src/samplingstats.jl index dea2b653..c5820dff 100644 --- a/src/samplingstats.jl +++ b/src/samplingstats.jl @@ -13,4 +13,4 @@ struct SamplingStats start::Float64 stop::Float64 duration::Float64 -end \ No newline at end of file +end diff --git a/src/stepper.jl b/src/stepper.jl index 34391851..18867c58 100644 --- a/src/stepper.jl +++ b/src/stepper.jl @@ -13,11 +13,7 @@ end Base.IteratorSize(::Type{<:Stepper}) = Base.IsInfinite() Base.IteratorEltype(::Type{<:Stepper}) = Base.EltypeUnknown() -function steps( - model::AbstractModel, - sampler::AbstractSampler; - kwargs... -) +function steps(model::AbstractModel, sampler::AbstractSampler; kwargs...) return steps(Random.GLOBAL_RNG, model, sampler; kwargs...) end @@ -46,10 +42,7 @@ true ``` """ function steps( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler; kwargs... ) return Stepper(rng, model, sampler, kwargs) end diff --git a/src/transducer.jl b/src/transducer.jl index 7aca51e0..51f9b358 100644 --- a/src/transducer.jl +++ b/src/transducer.jl @@ -1,4 +1,5 @@ -struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: Transducers.Transducer +struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: + Transducers.Transducer rng::A model::M sampler::S @@ -34,10 +35,7 @@ true ``` """ function Sample( - rng::Random.AbstractRNG, - model::AbstractModel, - sampler::AbstractSampler; - kwargs... + rng::Random.AbstractRNG, model::AbstractModel, sampler::AbstractSampler; kwargs... ) return Sample(rng, model, sampler, kwargs) end diff --git a/test/deprecations.jl b/test/deprecations.jl index f866668c..dd53cb42 100644 --- a/test/deprecations.jl +++ b/test/deprecations.jl @@ -1,4 +1,4 @@ @testset "deprecations.jl" begin @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler()) @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler(), 3) -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index c3f108e1..e8f09589 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,7 @@ using TerminalLoggers: TerminalLogger using Transducers using Distributed -import Logging +using Logging: Logging using Random using Statistics using Test diff --git a/test/sample.jl b/test/sample.jl index 19242629..debb2238 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -5,12 +5,13 @@ Random.seed!(1234) N = 1_000 - chain = sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + chain = sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) @test length(LOGGERS) == 1 logger = first(LOGGERS) @test logger isa TeeLogger - @test logger.loggers[1].logger isa (Sys.iswindows() && VERSION < v"1.5.3" ? ProgressLogger : TerminalLogger) + @test logger.loggers[1].logger isa + (Sys.iswindows() && VERSION < v"1.5.3" ? ProgressLogger : TerminalLogger) @test logger.loggers[2].logger === CURRENT_LOGGER @test Logging.current_logger() === CURRENT_LOGGER @@ -20,10 +21,10 @@ # test some statistical properties tail_chain = @view chain[2:end] - @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 - @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 - @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 - @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 + @test mean(x.a for x in tail_chain) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in tail_chain) ≈ 1 / 12 atol = 5e-3 + @test mean(x.b for x in tail_chain) ≈ 0.0 atol = 5e-2 + @test var(x.b for x in tail_chain) ≈ 1 atol = 6e-2 end @testset "Juno" begin @@ -34,7 +35,7 @@ logger = JunoProgressLogger() Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) end @test length(LOGGERS) == 1 @@ -52,7 +53,7 @@ Random.seed!(1234) N = 10 - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) @test length(LOGGERS) == 1 logger = first(LOGGERS) @@ -74,7 +75,7 @@ logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1)) Logging.with_logger(logger) do - sample(MyModel(), MySampler(), N; sleepy = true, loggers = true) + sample(MyModel(), MySampler(), N; sleepy=true, loggers=true) end @test length(LOGGERS) == 1 @@ -84,21 +85,25 @@ @testset "Suppress output" begin logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; progress = false, sleepy = true) + sample(MyModel(), MySampler(), 100; progress=false, sleepy=true) end @test all(l.level > Logging.LogLevel(-1) for l in logs) # disable progress logging globally - @test !(@test_logs (:info, "progress logging is disabled globally") AbstractMCMC.setprogress!(false)) + @test !(@test_logs (:info, "progress logging is disabled globally") AbstractMCMC.setprogress!( + false + )) @test !AbstractMCMC.PROGRESS[] logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), 100; sleepy = true) + sample(MyModel(), MySampler(), 100; sleepy=true) end @test all(l.level > Logging.LogLevel(-1) for l in logs) # enable progress logging globally - @test (@test_logs (:info, "progress logging is enabled globally") AbstractMCMC.setprogress!(true)) + @test (@test_logs (:info, "progress logging is enabled globally") AbstractMCMC.setprogress!( + true + )) @test AbstractMCMC.PROGRESS[] end end @@ -106,8 +111,9 @@ @testset "Multithreaded sampling" begin if Threads.nthreads() == 1 warnregex = r"^Only a single thread available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), - 10, 10) + @test_logs (:warn, warnregex) sample( + MyModel(), MySampler(), MCMCThreads(), 10, 10 + ) end # No dedicated chains type @@ -118,8 +124,7 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain) # test output type and size @test chains isa Vector{<:MyChain} @@ -134,36 +139,43 @@ # test reproducibility Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; chain_type=MyChain) @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCThreads(), 5, 10; - chain_type = MyChain) + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCThreads(), 5, 10; chain_type=MyChain + ) # Suppress output. logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; - progress = false, chain_type = MyChain) + sample( + MyModel(), + MySampler(), + MCMCThreads(), + 10_000, + 1000; + progress=false, + chain_type=MyChain, + ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) - + # Smoke test for nchains < nthreads if Threads.nthreads() == 2 - sample(MyModel(), MySampler(), MCMCThreads(), N, 1) + sample(MyModel(), MySampler(), MCMCThreads(), N, 1) end end @testset "Multicore sampling" begin if nworkers() == 1 warnregex = r"^Only a single process available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCDistributed(), - 10, 10; chain_type = MyChain) + @test_logs (:warn, warnregex) sample( + MyModel(), MySampler(), MCMCDistributed(), 10, 10; chain_type=MyChain + ) end # Add worker processes. @@ -188,8 +200,9 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; - chain_type = MyChain) + chains = sample( + MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain + ) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -205,22 +218,30 @@ # Test reproducibility. Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000; - chain_type = MyChain) + chains2 = sample( + MyModel(), MySampler(), MCMCDistributed(), N, 1000; chain_type=MyChain + ) @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCDistributed(), 5, 10; - chain_type = MyChain) + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCDistributed(), 5, 10; chain_type=MyChain + ) # Suppress output. logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 100; - progress = false, chain_type = MyChain) + sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 10_000, + 100; + progress=false, + chain_type=MyChain, + ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) end @@ -234,8 +255,7 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; - chain_type = MyChain) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -251,22 +271,28 @@ # Test reproducibility. Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; - chain_type = MyChain) + chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) # Unexpected order of arguments. str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCSerial(), 5, 10; - chain_type = MyChain) + @test_logs (:warn, str) match_mode = :any sample( + MyModel(), MySampler(), MCMCSerial(), 5, 10; chain_type=MyChain + ) # Suppress output. logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCSerial(), 10_000, 100; - progress = false, chain_type = MyChain) + sample( + MyModel(), + MySampler(), + MCMCSerial(), + 10_000, + 100; + progress=false, + chain_type=MyChain, + ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) end @@ -278,46 +304,73 @@ # Serial sampling Random.seed!(1234) chains_serial = sample( - MyModel(), MySampler(), MCMCSerial(), N, nchains; - progress=false, chain_type=MyChain + MyModel(), + MySampler(), + MCMCSerial(), + N, + nchains; + progress=false, + chain_type=MyChain, ) # Multi-threaded sampling Random.seed!(1234) chains_threads = sample( - MyModel(), MySampler(), MCMCThreads(), N, nchains; - progress=false, chain_type=MyChain + MyModel(), + MySampler(), + MCMCThreads(), + N, + nchains; + progress=false, + chain_type=MyChain, + ) + @test all( + c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), + i in 1:N + ) + @test all( + c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), + i in 1:N ) - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) # Multi-core sampling Random.seed!(1234) chains_distributed = sample( - MyModel(), MySampler(), MCMCDistributed(), N, nchains; - progress=false, chain_type=MyChain + MyModel(), + MySampler(), + MCMCDistributed(), + N, + nchains; + progress=false, + chain_type=MyChain, + ) + @test all( + c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), + i in 1:N + ) + @test all( + c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), + i in 1:N ) - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) end @testset "Chain constructors" begin - chain1 = sample(MyModel(), MySampler(), 100; sleepy = true) - chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain) + chain1 = sample(MyModel(), MySampler(), 100; sleepy=true) + chain2 = sample(MyModel(), MySampler(), 100; sleepy=true, chain_type=MyChain) @test chain1 isa Vector{<:MySample} @test chain2 isa MyChain end @testset "Sample stats" begin - chain = sample(MyModel(), MySampler(), 1000; chain_type = MyChain) - + chain = sample(MyModel(), MySampler(), 1000; chain_type=MyChain) + @test chain.stats.stop >= chain.stats.start @test chain.stats.duration == chain.stats.stop - chain.stats.start end @testset "Discard initial samples" begin - chain = sample(MyModel(), MySampler(), 100; sleepy = true, discard_initial = 50) + chain = sample(MyModel(), MySampler(), 100; sleepy=true, discard_initial=50) @test length(chain) == 100 @test !ismissing(chain[1].a) end @@ -327,17 +380,16 @@ Random.seed!(1234) N = 100 thinning = 3 - chain = sample(MyModel(), MySampler(), N; sleepy = true, thinning = thinning) + chain = sample(MyModel(), MySampler(), N; sleepy=true, thinning=thinning) @test length(chain) == N @test ismissing(chain[1].a) # Repeat sampling without thinning. Random.seed!(1234) - ref_chain = sample(MyModel(), MySampler(), N * thinning; sleepy = true) + ref_chain = sample(MyModel(), MySampler(), N * thinning; sleepy=true) @test all(chain[i].a === ref_chain[(i - 1) * thinning + 1].a for i in 1:N) end - @testset "Sample without predetermined N" begin Random.seed!(1234) chain = sample(MyModel(), MySampler()) @@ -346,20 +398,20 @@ @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Discard initial samples. - chain = sample(MyModel(), MySampler(); discard_initial = 50) + chain = sample(MyModel(), MySampler(); discard_initial=50) bmean = mean(x.b for x in chain) @test !ismissing(chain[1].a) @test abs(bmean) <= 0.001 || length(chain) == 10_000 # Thin chain by a factor of `thinning`. - chain = sample(MyModel(), MySampler(); thinning = 3) + chain = sample(MyModel(), MySampler(); thinning=3) bmean = mean(x.b for x in chain) @test ismissing(chain[1].a) @test abs(bmean) <= 0.001 || length(chain) == 10_000 end @testset "Sample vector of `NamedTuple`s" begin - chain = sample(MyModel(), MySampler(), 1_000; chain_type = Vector{NamedTuple}) + chain = sample(MyModel(), MySampler(), 1_000; chain_type=Vector{NamedTuple}) # Check output type @test chain isa Vector{<:NamedTuple} @test length(chain) == 1_000 @@ -367,15 +419,17 @@ # Check some statistical properties @test ismissing(chain[1].a) - @test mean(x.a for x in view(chain, 2:1_000)) ≈ 0.5 atol=6e-2 - @test var(x.a for x in view(chain, 2:1_000)) ≈ 1 / 12 atol=1e-2 - @test mean(x.b for x in chain) ≈ 0 atol=0.1 - @test var(x.b for x in chain) ≈ 1 atol=0.15 + @test mean(x.a for x in view(chain, 2:1_000)) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in view(chain, 2:1_000)) ≈ 1 / 12 atol = 1e-2 + @test mean(x.b for x in chain) ≈ 0 atol = 0.1 + @test var(x.b for x in chain) ≈ 1 atol = 0.15 end - + @testset "Testing callbacks" begin - function count_iterations(rng, model, sampler, sample, state, i; iter_array, kwargs...) - push!(iter_array, i) + function count_iterations( + rng, model, sampler, sample, state, i; iter_array, kwargs... + ) + return push!(iter_array, i) end N = 100 it_array = Float64[] @@ -384,7 +438,9 @@ # sampling without predetermined N it_array = Float64[] - chain = sample(MyModel(), MySampler(); callback=count_iterations, iter_array=it_array) + chain = sample( + MyModel(), MySampler(); callback=count_iterations, iter_array=it_array + ) @test it_array == collect(1:size(chain, 1)) end end diff --git a/test/stepper.jl b/test/stepper.jl index bc75d637..1b570557 100644 --- a/test/stepper.jl +++ b/test/stepper.jl @@ -5,7 +5,7 @@ bs = [] iter = AbstractMCMC.steps(MyModel(), MySampler()) - iter = AbstractMCMC.steps(MyModel(), MySampler(); a = 1.0) # `a` shouldn't do anything + iter = AbstractMCMC.steps(MyModel(), MySampler(); a=1.0) # `a` shouldn't do anything for (count, t) in enumerate(iter) if count >= 1000 @@ -21,10 +21,10 @@ @test length(as) == length(bs) == 998 - @test mean(as) ≈ 0.5 atol=2e-2 - @test var(as) ≈ 1 / 12 atol=5e-3 - @test mean(bs) ≈ 0.0 atol=5e-2 - @test var(bs) ≈ 1 atol=5e-2 + @test mean(as) ≈ 0.5 atol = 2e-2 + @test var(as) ≈ 1 / 12 atol = 5e-3 + @test mean(bs) ≈ 0.0 atol = 5e-2 + @test var(bs) ≈ 1 atol = 5e-2 @test Base.IteratorSize(iter) == Base.IsInfinite() @test Base.IteratorEltype(iter) == Base.EltypeUnknown() diff --git a/test/transducer.jl b/test/transducer.jl index 910f9d70..f9e1a049 100644 --- a/test/transducer.jl +++ b/test/transducer.jl @@ -5,9 +5,8 @@ N = 1_000 local chain Logging.with_logger(TerminalLogger()) do - xf = AbstractMCMC.Sample(MyModel(), MySampler(); - sleepy = true, logger = true) - chain = withprogress(1:N; interval=1e-3) |> xf |> collect + xf = AbstractMCMC.Sample(MyModel(), MySampler(); sleepy=true, logger=true) + chain = collect(xf(withprogress(1:N; interval=1e-3))) end # test output type and size @@ -16,15 +15,15 @@ # test some statistical properties tail_chain = @view chain[2:end] - @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 - @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 - @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 - @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 + @test mean(x.a for x in tail_chain) ≈ 0.5 atol = 6e-2 + @test var(x.a for x in tail_chain) ≈ 1 / 12 atol = 5e-3 + @test mean(x.b for x in tail_chain) ≈ 0.0 atol = 5e-2 + @test var(x.b for x in tail_chain) ≈ 1 atol = 6e-2 end @testset "drop" begin xf = AbstractMCMC.Sample(MyModel(), MySampler()) - chain = 1:10 |> xf |> Drop(1) |> collect + chain = collect(Drop(1)(xf(1:10))) @test chain isa Vector{MySample{Float64,Float64}} @test length(chain) == 9 end @@ -37,7 +36,7 @@ OfType(MySample{Float64,Float64}), Map(x -> (x.a, x.b)), ) - as, bs = foldl(xf, 1:999; init = (Float64[], Float64[])) do (as, bs), (a, b) + as, bs = foldl(xf, 1:999; init=(Float64[], Float64[])) do (as, bs), (a, b) push!(as, a) push!(bs, b) as, bs @@ -45,9 +44,9 @@ @test length(as) == length(bs) == 998 - @test mean(as) ≈ 0.5 atol=2e-2 - @test var(as) ≈ 1 / 12 atol=5e-3 - @test mean(bs) ≈ 0.0 atol=5e-2 - @test var(bs) ≈ 1 atol=5e-2 + @test mean(as) ≈ 0.5 atol = 2e-2 + @test var(as) ≈ 1 / 12 atol = 5e-3 + @test mean(bs) ≈ 0.0 atol = 5e-2 + @test var(bs) ≈ 1 atol = 5e-2 end end diff --git a/test/utils.jl b/test/utils.jl index cd3543b7..32474639 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -20,10 +20,10 @@ function AbstractMCMC.step( rng::AbstractRNG, model::MyModel, sampler::MySampler, - state::Union{Nothing,Integer} = nothing; - sleepy = false, - loggers = false, - kwargs... + state::Union{Nothing,Integer}=nothing; + sleepy=false, + loggers=false, + kwargs..., ) # sample `a` is missing in the first step a = state === nothing ? missing : rand(rng) @@ -43,8 +43,8 @@ function AbstractMCMC.bundle_samples( sampler::MySampler, ::Any, ::Type{MyChain}; - stats = nothing, - kwargs... + stats=nothing, + kwargs..., ) as = [t.a for t in samples] bs = [t.b for t in samples] @@ -59,7 +59,7 @@ function isdone( samples, state, iteration::Int; - kwargs... + kwargs..., ) # Calculate the mean of x.b. bmean = mean(x.b for x in samples) @@ -72,11 +72,10 @@ function AbstractMCMC.sample(model, sampler::MySampler; kwargs...) end function AbstractMCMC.chainscat( - chain::Union{MyChain,Vector{<:MyChain}}, - chains::Union{MyChain,Vector{<:MyChain}}... + chain::Union{MyChain,Vector{<:MyChain}}, chains::Union{MyChain,Vector{<:MyChain}}... ) return vcat(chain, chains...) end # Conversion to NamedTuple -Base.convert(::Type{NamedTuple}, x::MySample) = (a = x.a, b = x.b) +Base.convert(::Type{NamedTuple}, x::MySample) = (a=x.a, b=x.b)