diff --git a/Project.toml b/Project.toml
index 9e9bbbcea..bc6ddaedb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.33.1"
+version = "0.33.2"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.28"
+DynamicPPL = "0.28.1"
 Compat = "4.15.0"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3"
diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl
index fceaa6a42..596e6e283 100644
--- a/src/experimental/gibbs.jl
+++ b/src/experimental/gibbs.jl
@@ -78,7 +78,7 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi
     # Short-circuits the tilde assume if `vn` is present in `context`.
     if has_conditioned_gibbs(context, vns)
         value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
-        return value, broadcast_logpdf(right, values), vi
+        return value, broadcast_logpdf(right, value), vi
     end

     # Otherwise, falls back to the default behavior.
@@ -90,8 +90,8 @@ function DynamicPPL.dot_tilde_assume(
 )
     # Short-circuits the tilde assume if `vn` is present in `context`.
     if has_conditioned_gibbs(context, vns)
-        values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
-        return values, broadcast_logpdf(right, values), vi
+        value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
+        return value, broadcast_logpdf(right, value), vi
     end

     # Otherwise, falls back to the default behavior.
@@ -144,14 +144,14 @@ end
 Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned.
 """
 function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo)
-    return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
+    return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
 end
-function DynamicPPL.condition(
+function condition_gibbs(
     context::DynamicPPL.AbstractContext,
     varinfo::DynamicPPL.AbstractVarInfo,
     varinfos::DynamicPPL.AbstractVarInfo...
 )
-    return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...)
+    return condition_gibbs(condition_gibbs(context, varinfo), varinfos...)
 end
 # Allow calling this on a `DynamicPPL.Model` directly.
 function condition_gibbs(model::DynamicPPL.Model, values...)
@@ -238,6 +238,9 @@ function Gibbs(algs::Pair...)
     return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs)))
 end

+# TODO: Remove when no longer needed.
+DynamicPPL.getspace(::Gibbs) = ()
+
 struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
     vi::V
     states::S
@@ -252,6 +255,7 @@ function DynamicPPL.initialstep(
     model::DynamicPPL.Model,
     spl::DynamicPPL.Sampler{<:Gibbs},
     vi_base::DynamicPPL.AbstractVarInfo;
+    initial_params=nothing,
     kwargs...,
 )
     alg = spl.alg
@@ -260,15 +264,35 @@ function DynamicPPL.initialstep(

     # 1. Run the model once to get the varnames present + initial values to condition on.
     vi_base = DynamicPPL.VarInfo(model)
+
+    # Simple way of setting the initial parameters: set them in the `vi_base`
+    # if they are given so they propagate to the subset varinfos used by each sampler.
+    if initial_params !== nothing
+        vi_base = DynamicPPL.unflatten(vi_base, initial_params)
+    end
+
+    # Create the varinfos for each sampler.
     varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames)
+    initial_params_all = if initial_params === nothing
+        fill(nothing, length(varnames))
+    else
+        # Extract from the `vi_base`, which should have the values set correctly from above.
+        map(vi -> vi[:], varinfos)
+    end

     # 2. Construct a varinfo for every vn + sampler combo.
-    states_and_varinfos = map(samplers, varinfos) do sampler_local, varinfo_local
+    states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local
         # Construct the conditional model.
         model_local = make_conditional(model, varinfo_local, varinfos)

         # Take initial step.
-        new_state_local = last(AbstractMCMC.step(rng, model_local, sampler_local; kwargs...))
+        new_state_local = last(AbstractMCMC.step(
+            rng, model_local, sampler_local;
+            # FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
+            # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
+            initial_params=initial_params_local,
+            kwargs...
+        ))

         # Return the new state and the invlinked `varinfo`.
         vi_local_state = Turing.Inference.varinfo(new_state_local)
@@ -284,7 +308,7 @@
     varinfos = map(last, states_and_varinfos)

     # Update the base varinfo from the first varinfo and replace it.
-    varinfos_new = DynamicPPL.setindex!!(varinfos, vi_base, 1)
+    varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1)
     # Merge the updated initial varinfo with the rest of the varinfos + update the logp.
     vi = DynamicPPL.setlogp!!(
         reduce(merge, varinfos_new),
@@ -365,12 +389,7 @@ function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, s
 end

 # TODO: Remove `rng`?
-"""
-    recompute_logprob!!(rng, model, sampler, state)
-
-Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
-"""
-function recompute_logprob!!(
+function Turing.Inference.recompute_logprob!!(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
     sampler::DynamicPPL.Sampler,
@@ -436,7 +455,7 @@ function gibbs_step_inner(
         state_local,
         state_previous
     )
-        current_state_local = recompute_logprob!!(
+        state_local = Turing.Inference.recompute_logprob!!(
             rng,
             model_local,
             sampler_local,
@@ -450,7 +469,7 @@
             rng,
             model_local,
             sampler_local,
-            current_state_local;
+            state_local;
             kwargs...,
         ),
     )
diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl
index cc43451ea..5f05121e7 100644
--- a/src/mcmc/Inference.jl
+++ b/src/mcmc/Inference.jl
@@ -114,6 +114,8 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
     end
 end

+DynamicPPL.getspace(::ExternalSampler) = ()
+
 """
     requires_unconstrained_space(sampler::ExternalSampler)

diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl
index 2ccc38173..65aaa177b 100644
--- a/src/mcmc/abstractmcmc.jl
+++ b/src/mcmc/abstractmcmc.jl
@@ -17,10 +17,61 @@ function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transit
     return transition_to_turing(parent(f), transition)
 end

+"""
+    getmodel(f)
+
+Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
+"""
+getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
+getmodel(f::DynamicPPL.LogDensityFunction) = f.model
+
+"""
+    setmodel(f, model[, adtype])
+
+Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
+
+!!! warning
+    Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
+    `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
+    might require recompilation of the gradient tape, depending on the AD backend.
+"""
+function setmodel(
+    f::LogDensityProblemsAD.ADGradientWrapper,
+    model::DynamicPPL.Model,
+    adtype::ADTypes.AbstractADType
+)
+    # TODO: Should we handle `SciMLBase.NoAD`?
+    # For an `ADGradientWrapper` we do the following:
+    # 1. Update the `Model` in the underlying `LogDensityFunction`.
+    # 2. Re-construct the `ADGradientWrapper` via `ADgradient` with the provided `adtype`
+    #    to ensure that the recompilation of gradient tapes, etc. also occurs. For example,
+    #    ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
+    #    replacing the corresponding field with the new model won't be sufficient to obtain
+    #    the correct gradients.
+    return LogDensityProblemsAD.ADgradient(adtype, setmodel(parent(f), model))
+end
+function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
+    return Accessors.@set f.model = model
+end
+
+function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
+    return varinfo_from_logdensityfn(parent(f))
+end
+varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo
+
+function varinfo(state::TuringState)
+    θ = getparams(getmodel(state.logdensity), state.state)
+    # TODO: Do we need to link here first?
+    return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)
+end
+
 # NOTE: Only thing that depends on the underlying sampler.
 # Something similar should be part of AbstractMCMC at some point:
 # https://github.com/TuringLang/AbstractMCMC.jl/pull/86
 getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
+function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
+    return getparams(model, state.transition)
+end
 getstats(transition::AdvancedHMC.Transition) = transition.stat

 getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
@@ -33,13 +84,59 @@ function setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo)
     return Accessors.@set f.ℓ = setvarinfo(f.ℓ, varinfo)
 end

+"""
+    recompute_logprob!!(rng, model, sampler, state)
+
+Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
+"""
+function recompute_logprob!!(
+    rng::Random.AbstractRNG, # TODO: Do we need the `rng` here?
+    model::DynamicPPL.Model,
+    sampler::DynamicPPL.Sampler{<:ExternalSampler},
+    state,
+)
+    # Re-use the log-density function from the `state` and update only the `model` field,
+    # since the `model` might now contain different conditioning values.
+    f = setmodel(state.logdensity, model, sampler.alg.adtype)
+    # Recompute the log-probability with the new `model`.
+    state_inner = recompute_logprob!!(
+        rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state
+    )
+    return state_to_turing(f, state_inner)
+end
+
+function recompute_logprob!!(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    sampler::AdvancedHMC.AbstractHMCSampler,
+    state::AdvancedHMC.HMCState,
+)
+    # Construct the Hamiltonian.
+    hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
+    # Re-compute the log-probability and gradient.
+    return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
+        hamiltonian, state.transition.z.θ, state.transition.z.r
+    )
+end
+
+function recompute_logprob!!(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    sampler::AdvancedMH.MetropolisHastings,
+    state::AdvancedMH.Transition,
+)
+    logdensity = model.logdensity
+    return Accessors.@set state.lp = LogDensityProblems.logdensity(logdensity, state.params)
+end
+
+# TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
     sampler_wrapper::Sampler{<:ExternalSampler};
     initial_state=nothing,
     initial_params=nothing,
-    kwargs...
+    kwargs...,
 )
     alg = sampler_wrapper.alg
     sampler = alg.sampler
@@ -69,7 +166,12 @@ function AbstractMCMC.step(
         )
     else
         transition_inner, state_inner = AbstractMCMC.step(
-            rng, AbstractMCMC.LogDensityModel(f), sampler, initial_state; initial_params, kwargs...
+            rng,
+            AbstractMCMC.LogDensityModel(f),
+            sampler,
+            initial_state;
+            initial_params,
+            kwargs...,
         )
     end
     # Update the `state`
@@ -81,7 +183,7 @@ function AbstractMCMC.step(
     model::DynamicPPL.Model,
     sampler_wrapper::Sampler{<:ExternalSampler},
     state::TuringState;
-    kwargs...
+    kwargs...,
 )
     sampler = sampler_wrapper.alg.sampler
     f = state.logdensity
diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl
index c8d42e338..0f0740f14 100644
--- a/test/experimental/gibbs.jl
+++ b/test/experimental/gibbs.jl
@@ -2,11 +2,14 @@ module ExperimentalGibbsTests

 using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo
 using ..NumericalTests: check_MoGtest_default, check_MoGtest_default_z_vector, check_gdemo,
-    check_numerical
+    check_numerical, two_sample_test
 using DynamicPPL
 using Random
 using Test
 using Turing
+using Turing.Inference: AdvancedHMC, AdvancedMH
+using ForwardDiff: ForwardDiff
+using ReverseDiff: ReverseDiff

 function check_transition_varnames(
     transition::Turing.Inference.Transition,
@@ -32,122 +35,159 @@ const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{
 has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false
 has_dot_assume(::Model) = true

-# Likely an issue with not linking correctly.
-@testset "Demo models" begin
-    @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
-        vns = DynamicPPL.TestUtils.varnames(model)
-        # Run one sampler on variables starting with `s` and another on variables starting with `m`.
-        vns_s = filter(vns) do vn
-            DynamicPPL.getsym(vn) == :s
-        end
-        vns_m = filter(vns) do vn
-            DynamicPPL.getsym(vn) == :m
-        end
+@testset "Gibbs using `condition`" begin
+    @testset "Demo models" begin
+        @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
+            vns = DynamicPPL.TestUtils.varnames(model)
+            # Run one sampler on variables starting with `s` and another on variables starting with `m`.
+            vns_s = filter(vns) do vn
+                DynamicPPL.getsym(vn) == :s
+            end
+            vns_m = filter(vns) do vn
+                DynamicPPL.getsym(vn) == :m
+            end

-        samplers = [
-            Turing.Experimental.Gibbs(
-                vns_s => NUTS(),
-                vns_m => NUTS(),
-            ),
-            Turing.Experimental.Gibbs(
-                vns_s => NUTS(),
-                vns_m => HMC(0.01, 4),
-            )
-        ]
+            samplers = [
+                Turing.Experimental.Gibbs(
+                    vns_s => NUTS(),
+                    vns_m => NUTS(),
+                ),
+                Turing.Experimental.Gibbs(
+                    vns_s => NUTS(),
+                    vns_m => HMC(0.01, 4),
+                )
+            ]

-        if !has_dot_assume(model)
-            # Add in some MH samplers, which are not compatible with `.~`.
-            append!(
-                samplers,
-                [
-                    Turing.Experimental.Gibbs(
-                        vns_s => HMC(0.01, 4),
-                        vns_m => MH(),
-                    ),
-                    Turing.Experimental.Gibbs(
-                        vns_s => MH(),
-                        vns_m => HMC(0.01, 4),
-                    )
-                ]
-            )
-        end
+            if !has_dot_assume(model)
+                # Add in some MH samplers, which are not compatible with `.~`.
+                append!(
+                    samplers,
+                    [
+                        Turing.Experimental.Gibbs(
+                            vns_s => HMC(0.01, 4),
+                            vns_m => MH(),
+                        ),
+                        Turing.Experimental.Gibbs(
+                            vns_s => MH(),
+                            vns_m => HMC(0.01, 4),
+                        )
+                    ]
+                )
+            end

-        @testset "$sampler" for sampler in samplers
-            # Check that taking steps performs as expected.
-            rng = Random.default_rng()
-            transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler))
-            check_transition_varnames(transition, vns)
-            for _ = 1:5
-                transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state)
+            @testset "$sampler" for sampler in samplers
+                # Check that taking steps performs as expected.
+                rng = Random.default_rng()
+                transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler))
                 check_transition_varnames(transition, vns)
+                for _ = 1:5
+                    transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state)
+                    check_transition_varnames(transition, vns)
+                end
             end
-        end
-    end
-end
-@testset "Gibbs using `condition`" begin
-    @testset "demo_assume_dot_observe" begin
-        model = DynamicPPL.TestUtils.demo_assume_dot_observe()
+            @testset "comparison with 'gold-standard' samples" begin
+                num_iterations = 1_000
+                thinning = 10
+                num_chains = 4

-        # Sample!
-        rng = Random.default_rng()
-        vns = [@varname(s), @varname(m)]
-        sampler = Turing.Experimental.Gibbs(map(Base.Fix2(Pair, MH()), vns)...)
+                # Determine initial parameters to make comparison as fair as possible.
+                posterior_mean = DynamicPPL.TestUtils.posterior_mean(model)
+                initial_params = DynamicPPL.TestUtils.update_values!!(
+                    DynamicPPL.VarInfo(model),
+                    posterior_mean,
+                    DynamicPPL.TestUtils.varnames(model),
+                )[:]
+                initial_params = fill(initial_params, num_chains)

-        @testset "step" begin
-            transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler))
-            check_transition_varnames(transition, vns)
-            for _ = 1:5
-                transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state)
-                check_transition_varnames(transition, vns)
-            end
-        end
+                # Sampler to use for Gibbs components.
+                sampler_inner = HMC(0.1, 32)
+                sampler = Turing.Experimental.Gibbs(
+                    vns_s => sampler_inner,
+                    vns_m => sampler_inner,
+                )
+                Random.seed!(42)
+                chain = sample(
+                    model,
+                    sampler,
+                    MCMCThreads(),
+                    num_iterations,
+                    num_chains;
+                    progress=false,
+                    initial_params=initial_params,
+                    discard_initial=1_000,
+                    thinning=thinning
+                )

-        @testset "sample" begin
-            chain = sample(model, sampler, 1000; progress=false)
-            @test size(chain, 1) == 1000
-            display(mean(chain))
-        end
-    end
+                # "Ground truth" samples.
+                # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL.
+                Random.seed!(42)
+                chain_true = sample(
+                    model,
+                    NUTS(),
+                    MCMCThreads(),
+                    num_iterations,
+                    num_chains;
+                    progress=false,
+                    initial_params=initial_params,
+                    thinning=thinning,
+                )

-    @testset "gdemo with CSMC & ESS" begin
-        Random.seed!(100)
-        alg = Turing.Experimental.Gibbs(@varname(s) => CSMC(15), @varname(m) => ESS())
-        chain = sample(gdemo(1.5, 2.0), alg, 10_000; progress=false)
-        check_gdemo(chain)
+                # Perform KS test to ensure that the chains are similar.
+                xs = Array(chain)
+                xs_true = Array(chain_true)
+                for i = 1:size(xs, 2)
+                    @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true)
+                    # Let's make sure that the significance level is not too low by
+                    # checking that the KS test fails for some simple transformations.
+                    # TODO: Replace the heuristic below with closed-form implementations
+                    # of the targets, once they are implemented in DynamicPPL.
+                    @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i])
+                    @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i])
+                    @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i])
+                end
+            end
+        end
     end

     @testset "multiple varnames" begin
         rng = Random.default_rng()
-        # With both `s` and `m` as random.
-        model = gdemo(1.5, 2.0)
-        vns = (@varname(s), @varname(m))
-        alg = Turing.Experimental.Gibbs(vns => MH())
+        @testset "with both `s` and `m` as random" begin
+            model = gdemo(1.5, 2.0)
+            vns = (@varname(s), @varname(m))
+            alg = Turing.Experimental.Gibbs(vns => MH())

-        # `step`
-        transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg))
-        check_transition_varnames(transition, vns)
-        for _ = 1:5
-            transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state)
+            # `step`
+            transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg))
             check_transition_varnames(transition, vns)
-        end
+            for _ in 1:5
+                transition, state = AbstractMCMC.step(
+                    rng, model, DynamicPPL.Sampler(alg), state
+                )
+                check_transition_varnames(transition, vns)
+            end

-        # `sample`
-        chain = sample(model, alg, 10_000; progress=false)
-        check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.4)
+            # `sample`
+            Random.seed!(42)
+            chain = sample(model, alg, 10_000; progress=false)
+            check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4)
+        end

-        # Without `m` as random.
-        model = gdemo(1.5, 2.0) | (m = 7 / 6,)
-        vns = (@varname(s),)
-        alg = Turing.Experimental.Gibbs(vns => MH())
+        @testset "without `m` as random" begin
+            model = gdemo(1.5, 2.0) | (m=7 / 6,)
+            vns = (@varname(s),)
+            alg = Turing.Experimental.Gibbs(vns => MH())

-        # `step`
-        transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg))
-        check_transition_varnames(transition, vns)
-        for _ = 1:5
-            transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state)
+            # `step`
+            transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg))
             check_transition_varnames(transition, vns)
+            for _ in 1:5
+                transition, state = AbstractMCMC.step(
+                    rng, model, DynamicPPL.Sampler(alg), state
+                )
+                check_transition_varnames(transition, vns)
+            end
         end
     end
@@ -169,6 +209,7 @@ end
         end

         # Sample!
+        Random.seed!(42)
         chain = sample(MoGtest_default, alg, 1000; progress=false)
         check_MoGtest_default(chain, atol = 0.2)
     end
@@ -191,9 +232,39 @@ end
         end

         # Sample!
+        Random.seed!(42)
         chain = sample(model, alg, 1000; progress=false)
         check_MoGtest_default_z_vector(chain, atol = 0.2)
     end
+
+    @testset "externalsampler" begin
+        @model function demo_gibbs_external()
+            m1 ~ Normal()
+            m2 ~ Normal()
+
+            -1 ~ Normal(m1, 1)
+            +1 ~ Normal(m1 + m2, 1)
+
+            return (; m1, m2)
+        end
+
+        model = demo_gibbs_external()
+        samplers_inner = [
+            externalsampler(AdvancedMH.RWMH(1)),
+            externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoForwardDiff()),
+            externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff()),
+            externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff(compile=true)),
+        ]
+        @testset "$(sampler_inner)" for sampler_inner in samplers_inner
+            sampler = Turing.Experimental.Gibbs(
+                @varname(m1) => sampler_inner,
+                @varname(m2) => sampler_inner,
+            )
+            Random.seed!(42)
+            chain = sample(model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0)
+            check_numerical(chain, [:m1, :m2], [-0.2, 0.6], atol=0.1)
+        end
+    end
 end

 end
diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl
index cb583b517..c44c502c1 100644
--- a/test/test_utils/numerical_tests.jl
+++ b/test/test_utils/numerical_tests.jl
@@ -3,6 +3,7 @@ module NumericalTests
 using Distributions
 using MCMCChains: namesingroup
 using Test: @test, @testset
+using HypothesisTests: HypothesisTests

 export check_MoGtest_default, check_MoGtest_default_z_vector, check_dist_numerical,
     check_gdemo, check_numerical
@@ -81,4 +82,33 @@ function check_MoGtest_default_z_vector(chain; atol=0.2, rtol=0.0)
             atol=atol, rtol=rtol)
 end

+"""
+    two_sample_test(xs_left, xs_right; α=1e-3, warn_on_fail=false)
+
+Perform a two-sample hypothesis test on the two samples `xs_left` and `xs_right`.
+
+Currently the test performed is a Kolmogorov-Smirnov (KS) test.
+
+# Arguments
+- `xs_left::AbstractVector`: samples from the first distribution.
+- `xs_right::AbstractVector`: samples from the second distribution.
+
+# Keyword arguments
+- `α::Real`: significance level for the test. Default: `1e-3`.
+- `warn_on_fail::Bool`: whether to warn if the test fails, which makes failures
+  easier to diagnose. Default: `false`.
+"""
+function two_sample_test(xs_left, xs_right; α=1e-3, warn_on_fail=false)
+    t = HypothesisTests.ApproximateTwoSampleKSTest(xs_left, xs_right)
+    # Just a way to make the logs a bit more informative in case of failure.
+    if HypothesisTests.pvalue(t) > α
+        true
+    else
+        warn_on_fail && @warn "Two-sample KS test failed with p-value $(HypothesisTests.pvalue(t))"
+        warn_on_fail && @warn "Means of the two samples: $(mean(xs_left)), $(mean(xs_right))"
+        warn_on_fail && @warn "Variances of the two samples: $(var(xs_left)), $(var(xs_right))"
+        false
+    end
+end
+
 end
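A minimal end-to-end sketch of the headline capability this patch enables: using external samplers as Gibbs components. This is distilled from the `externalsampler` test added above; the model, sampler construction, and keyword arguments mirror that test, and it assumes a Turing.jl build containing this patch (`Turing.Experimental.Gibbs` is still experimental).

```julia
using Turing
using Turing.Inference: AdvancedMH

# Same toy model as in the new test: two Gaussian means with two observations.
@model function demo_gibbs_external()
    m1 ~ Normal()
    m2 ~ Normal()

    -1 ~ Normal(m1, 1)
    +1 ~ Normal(m1 + m2, 1)

    return (; m1, m2)
end

# Each Gibbs component may now be a wrapped external sampler; conditioning values
# are kept in sync via the `setmodel`/`recompute_logprob!!` machinery added in
# src/mcmc/abstractmcmc.jl.
sampler = Turing.Experimental.Gibbs(
    @varname(m1) => externalsampler(AdvancedMH.RWMH(1)),
    @varname(m2) => externalsampler(AdvancedMH.RWMH(1)),
)
# Keyword arguments as in the test; `n_adapts=0` is only relevant for HMC-style components.
chain = sample(demo_gibbs_external(), sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0)
```

The new `two_sample_test` helper can likewise be sanity-checked standalone. A quick illustration of the KS test it wraps, on synthetic data with the patch's default significance level `α = 1e-3` (hypothetical REPL usage, not part of the patch):

```julia
using HypothesisTests: ApproximateTwoSampleKSTest, pvalue

xs, ys = randn(10_000), randn(10_000)
pvalue(ApproximateTwoSampleKSTest(xs, ys)) > 1e-3         # expected: true (same distribution)
pvalue(ApproximateTwoSampleKSTest(xs .+ 0.5, ys)) > 1e-3  # expected: false (shift is detectable)
```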