From ddac60fb57b8dba6955922cec46573b310dbc0c1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 1 Mar 2023 00:39:04 +0100 Subject: [PATCH] Add compatibility with MCMCDiagnosticTools v0.3 (#401) * Bump MCMCDiagnosticTools compat * Update imported/exported methods * Remove type constraint on classifier * Overload and export mcse * Overload and update ess and rhat * Update summarystats * Update tests * Increment major version * Rename ess.jl to ess_rhat.jl * Add back ess_per_sec * Fix bug constructing ess_per_sec * Update ess_rhat tests * Test mcse * Update docs * Remove deprecations * Remove unused import * Revert "Fix MLJDecisionTreeInterface to 0.3.0 (#402)" This reverts commit 991f10b99fca0c6c74608df62d64f17df213f680. * Always include ess_per_sec in table * Use isequal to pass with missing values * Use isequal for missing * Remove naive_se Fixes #351 * Test Tables interface before loading StatsPlots DataValues (a StatsPlots dependency) pirates a convert method that causes the Tables equality tests with `missing` to fail. 
See https://github.com/queryverse/DataValues.jl * Apply suggestions from code review Co-authored-by: David Widmann --------- Co-authored-by: David Widmann --- Project.toml | 5 +-- docs/Project.toml | 4 +- docs/src/diagnostics.md | 3 +- src/MCMCChains.jl | 27 ++++--------- src/ess.jl | 33 ---------------- src/ess_rhat.jl | 85 +++++++++++++++++++++++++++++++++++++++++ src/mcse.jl | 22 +++++++++++ src/rstar.jl | 4 +- src/stats.jl | 36 +++++++++++------ test/Project.toml | 4 +- test/ess_rhat_tests.jl | 55 ++++++++++++++++++++++++++ test/ess_tests.jl | 54 -------------------------- test/mcse_tests.jl | 34 +++++++++++++++++ test/runtests.jl | 18 +++++---- test/summarize_tests.jl | 6 +-- test/tables_tests.jl | 29 +++++++------- 16 files changed, 268 insertions(+), 151 deletions(-) delete mode 100644 src/ess.jl create mode 100644 src/ess_rhat.jl create mode 100644 src/mcse.jl create mode 100644 test/ess_rhat_tests.jl delete mode 100644 test/ess_tests.jl create mode 100644 test/mcse_tests.jl diff --git a/Project.toml b/Project.toml index 8fb60638..5fc4b7ce 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." 
-version = "5.7.1" +version = "6.0.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -21,7 +21,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" @@ -35,7 +34,7 @@ Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" Formatting = "0.4" IteratorInterfaceExtensions = "0.1.1, 1" KernelDensity = "0.6.2" -MCMCDiagnosticTools = "0.2" +MCMCDiagnosticTools = "0.3" MLJModelInterface = "0.3.5, 0.4, 1.0" NaturalSort = "1" OrderedCollections = "1.4" diff --git a/docs/Project.toml b/docs/Project.toml index 955f10b9..8026a7ac 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,8 +15,8 @@ CategoricalArrays = "0.8, 0.9, 0.10" DataFrames = "0.22, 1" Documenter = "0.26, 0.27" Gadfly = "1.3.4" -MCMCChains = "5" +MCMCChains = "6" MLJBase = "0.19, 0.20, 0.21" -MLJDecisionTreeInterface = "=0.3.0" +MLJDecisionTreeInterface = "0.3" StatsPlots = "0.14, 0.15" julia = "1.7" diff --git a/docs/src/diagnostics.md b/docs/src/diagnostics.md index 496d885c..af2c7195 100644 --- a/docs/src/diagnostics.md +++ b/docs/src/diagnostics.md @@ -9,6 +9,7 @@ Pages = [ "heideldiag.jl", "rafterydiag.jl", "rstar.jl", - "ess.jl" + "ess_rhat.jl", + "mcse.jl", ] ``` diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index 51a37a9f..3422620f 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -24,7 +24,6 @@ import IteratorInterfaceExtensions import LinearAlgebra import Random -import Serialization import Statistics: std, cor, mean, var, mean! 
export Chains, chains, chainscat @@ -36,13 +35,15 @@ export ChainDataFrame export summarize # Reexport diagnostics functions -using MCMCDiagnosticTools: discretediag, ess_rhat, ESSMethod, FFTESSMethod, BDAESSMethod, - gelmandiag, gelmandiag_multivariate, gewekediag, heideldiag, rafterydiag, rstar +using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod, + BDAAutocovMethod, gelmandiag, gelmandiag_multivariate, gewekediag, heideldiag, mcse, + rafterydiag, rhat, rstar export discretediag -export ess_rhat, ESSMethod, FFTESSMethod, BDAESSMethod +export ess, ess_rhat, rhat, AutocovMethod, FFTAutocovMethod, BDAAutocovMethod export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag +export mcse export rafterydiag export rstar @@ -69,13 +70,14 @@ end include("utils.jl") include("chains.jl") include("constructors.jl") -include("ess.jl") +include("ess_rhat.jl") include("summarize.jl") include("discretediag.jl") include("fileio.jl") include("gelmandiag.jl") include("gewekediag.jl") include("heideldiag.jl") +include("mcse.jl") include("rafterydiag.jl") include("sampling.jl") include("stats.jl") @@ -84,19 +86,4 @@ include("plot.jl") include("tables.jl") include("rstar.jl") -# deprecations -# TODO: Remove dependency on Serialization if this deprecation is removed -# somehow `@deprecate` doesn't work with qualified function names, -# so we use the following hack -const _read = Base.read -const _write = Base.write -Base.@deprecate _read( - f::AbstractString, - ::Type{T} -) where {T<:Chains} Serialization.deserialize(f) false -Base.@deprecate _write( - f::AbstractString, - c::Chains -) Serialization.serialize(f, c) false - end # module diff --git a/src/ess.jl b/src/ess.jl deleted file mode 100644 index 4669aba5..00000000 --- a/src/ess.jl +++ /dev/null @@ -1,33 +0,0 @@ -""" - ess_rhat(chains::Chains; duration=compute_duration, kwargs...) - -Estimate the effective sample size and the potential scale reduction. 
- -ESS per second options include `duration=MCMCChains.compute_duration` (the default) -and `duration=MCMCChains.wall_duration`. -""" -function MCMCDiagnosticTools.ess_rhat( - chains::Chains; - sections = _default_sections(chains), duration = compute_duration, kwargs... -) - # Subset the chain - _chains = Chains(chains, _clean_sections(chains, sections)) - - # Estimate the effective sample size and rhat - ess, rhat = MCMCDiagnosticTools.ess_rhat( - _permutedims_diagnostics(_chains.value.data); - kwargs..., - ) - - # Calculate ESS/minute if available - dur = duration(chains) - - # Convert to NamedTuple - nt = if dur === missing - merge((parameters = names(_chains),), (ess = ess, rhat = rhat)) - else - merge((parameters = names(_chains),), (ess = ess, rhat = rhat, ess_per_sec=ess/dur)) - end - - return ChainDataFrame("ESS", nt) -end diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl new file mode 100644 index 00000000..70d86d2b --- /dev/null +++ b/src/ess_rhat.jl @@ -0,0 +1,85 @@ +""" + ess(chains::Chains; duration=compute_duration, kwargs...) + +Estimate the effective sample size. + +ESS per second options include `duration=MCMCChains.compute_duration` (the default) +and `duration=MCMCChains.wall_duration`. +""" +function MCMCDiagnosticTools.ess( + chains::Chains; + sections = _default_sections(chains), duration = compute_duration, kwargs... +) + # Subset the chain + _chains = Chains(chains, _clean_sections(chains, sections)) + + # Estimate the effective sample size + ess = MCMCDiagnosticTools.ess( + _permutedims_diagnostics(_chains.value.data); + kwargs..., + ) + + # Get the duration used to compute ESS per second + dur = duration(chains) + + # Convert to NamedTuple + ess_per_sec = ess ./ dur + nt = merge((parameters = names(_chains),), (; ess, ess_per_sec)) + + return ChainDataFrame("ESS", nt) +end + +""" + rhat(chains::Chains; kwargs...) + +Estimate the ``\\widehat{R}`` diagnostic. 
+""" +function MCMCDiagnosticTools.rhat( + chains::Chains; + sections = _default_sections(chains), kwargs... +) + # Subset the chain + _chains = Chains(chains, _clean_sections(chains, sections)) + + # Estimate the rhat + rhat = MCMCDiagnosticTools.rhat( + _permutedims_diagnostics(_chains.value.data); + kwargs..., + ) + + # Convert to NamedTuple + nt = merge((parameters = names(_chains),), (; rhat)) + + return ChainDataFrame("R-hat", nt) +end + +""" + ess_rhat(chains::Chains; duration=compute_duration, kwargs...) + +Estimate the effective sample size and the ``\\widehat{R}`` diagnostic. + +ESS per second options include `duration=MCMCChains.compute_duration` (the default) +and `duration=MCMCChains.wall_duration`. +""" +function MCMCDiagnosticTools.ess_rhat( + chains::Chains; + sections = _default_sections(chains), duration = compute_duration, kwargs... +) + # Subset the chain + _chains = Chains(chains, _clean_sections(chains, sections)) + + # Estimate the effective sample size and rhat + ess_rhat = MCMCDiagnosticTools.ess_rhat( + _permutedims_diagnostics(_chains.value.data); + kwargs..., + ) + + # Get the duration used to compute ESS per second + dur = duration(chains) + + # Convert to NamedTuple + ess_per_sec = ess_rhat.ess ./ dur + nt = merge((parameters = names(_chains),), ess_rhat, (; ess_per_sec)) + + return ChainDataFrame("ESS/R-hat", nt) +end diff --git a/src/mcse.jl b/src/mcse.jl new file mode 100644 index 00000000..78ae2552 --- /dev/null +++ b/src/mcse.jl @@ -0,0 +1,22 @@ +""" + mcse(chains::Chains; kwargs...) + +Estimate the Monte Carlo standard error. +""" +function MCMCDiagnosticTools.mcse( + chains::Chains; + sections = _default_sections(chains), kwargs... 
+) + # Subset the chain + _chains = Chains(chains, _clean_sections(chains, sections)) + + # Estimate the Monte Carlo standard error + mcse = MCMCDiagnosticTools.mcse( + _permutedims_diagnostics(_chains.value.data); + kwargs..., + ) + + nt = merge((parameters = names(_chains),), (; mcse)) + + return ChainDataFrame("MCSE", nt) +end diff --git a/src/rstar.jl b/src/rstar.jl index 26b9b6c4..fe551966 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -38,14 +38,14 @@ true ``` """ function MCMCDiagnosticTools.rstar( - classif::MLJModelInterface.Supervised, chn::Chains; kwargs... + classif, chn::Chains; kwargs... ) return MCMCDiagnosticTools.rstar(Random.GLOBAL_RNG, classif, chn; kwargs...) end function MCMCDiagnosticTools.rstar( rng::Random.AbstractRNG, - classif::MLJModelInterface.Supervised, + classif, chn::Chains; sections = _default_sections(chn), kwargs... diff --git a/src/stats.jl b/src/stats.jl index f299e0e4..ba744d64 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -270,14 +270,13 @@ end chains; sections = _default_sections(chains), append_chains= true, - method::AbstractESSMethod = ESSMethod(), + autocov_method::AbstractAutocovMethod = AutocovMethod(), maxlag = 250, - etype = :bm, kwargs... ) -Compute the mean, standard deviation, naive standard error, Monte Carlo standard error, -and effective sample size for each parameter in the chain. +Compute the mean, standard deviation, Monte Carlo standard error, bulk- and tail-effective +sample size, ``\\widehat{R}`` diagnostic, and ESS per second for each parameter in the chain. Setting `append_chains=false` will return a vector of dataframes containing the summary statistics for each chain. @@ -288,27 +287,42 @@ function summarystats( chains::Chains; sections = _default_sections(chains), append_chains::Bool = true, - method::MCMCDiagnosticTools.AbstractESSMethod = ESSMethod(), + autocov_method::MCMCDiagnosticTools.AbstractAutocovMethod = AutocovMethod(), maxlag = 250, - etype = :bm, kwargs... ) # Store everything. 
- funs = [mean∘cskip, std∘cskip, sem∘cskip, x -> MCMCDiagnosticTools.mcse(cskip(x); method=etype, kwargs...)] - func_names = [:mean, :std, :naive_se, :mcse] + funs = [mean∘cskip, std∘cskip] + func_names = [:mean, :std] # Subset the chain. _chains = Chains(chains, _clean_sections(chains, sections)) - # Calculate ESS separately. - ess_df = MCMCDiagnosticTools.ess_rhat(_chains; sections = nothing, method = method, maxlag = maxlag) + # Calculate MCSE and ESS/R-hat separately. + mcse_df = MCMCDiagnosticTools.mcse( + _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, + ) + ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat( + _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank + ) + ess_tail_df = MCMCDiagnosticTools.ess( + _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail + ) + nt_additional = ( + mcse=mcse_df.nt.mcse, + ess_bulk=ess_rhat_rank_df.nt.ess, + ess_tail=ess_tail_df.nt.ess, + rhat=ess_rhat_rank_df.nt.rhat, + ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec, + ) + additional_df = ChainDataFrame("Additional", nt_additional) # Summarize. 
summary_df = summarize( _chains, funs...; func_names = func_names, append_chains = append_chains, - additional_df = ess_df, + additional_df = additional_df, name = "Summary Statistics", sections = nothing ) diff --git a/test/Project.toml b/test/Project.toml index 783e2882..613af823 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,9 +30,9 @@ Documenter = "0.26, 0.27" FFTW = "1.1" IteratorInterfaceExtensions = "1" KernelDensity = "0.6.2" -MCMCChains = "5" +MCMCChains = "6" MLJBase = "0.18, 0.19, 0.20, 0.21" -MLJDecisionTreeInterface = "=0.3.0" +MLJDecisionTreeInterface = "0.3" StatsBase = "0.33.2" StatsPlots = "0.14.17, 0.15" TableTraits = "1" diff --git a/test/ess_rhat_tests.jl b/test/ess_rhat_tests.jl new file mode 100644 index 00000000..2545c494 --- /dev/null +++ b/test/ess_rhat_tests.jl @@ -0,0 +1,55 @@ +using MCMCChains +using FFTW + +using Random +using Statistics +using Test + +@testset "ESS per second" begin + c1 = Chains(randn(100,5, 1), info = (start_time=time(), stop_time = time()+1)) + c2 = Chains(randn(100,5, 1), info = (start_time=time()+1, stop_time = time()+2)) + c = chainscat(c1, c2) + + wall = MCMCChains.wall_duration(c) + compute = MCMCChains.compute_duration(c) + + @test round(wall, digits=1) ≈ round(c2.info.stop_time - c1.info.start_time, digits=1) + @test compute ≈ (MCMCChains.compute_duration(c1) + MCMCChains.compute_duration(c2)) + + for f in (ess, ess_rhat) + s = f(c) + @test length(s[:,:ess_per_sec]) == 5 + @test all(map(!ismissing, s[:,:ess_per_sec])) + end +end + +@testset "ess/rhat/ess_rhat (chains)" begin + x = rand(10_000, 40, 10) + chain = Chains(x) + + for autocov_method in (AutocovMethod(), FFTAutocovMethod(), BDAAutocovMethod()), kind in (:bulk, :basic), f in (ess, ess_rhat, rhat) + # analyze chain (NOTE(review): loop variable `f` is unused here, so each case runs 3x) + ess_df = ess(chain; autocov_method = autocov_method, kind = kind) + rhat_df = rhat(chain; kind = kind) + ess_rhat_df = ess_rhat(chain; autocov_method = autocov_method, kind = kind) + + # analyze array + ess_array, 
rhat_array = ess_rhat( + permutedims(x, (1, 3, 2)); autocov_method = autocov_method, kind = kind, + ) + @test ess_df[:,2] == ess_rhat_df[:,2] == ess_array + @test rhat_df[:,2] == ess_rhat_df[:,3] == rhat_array + end +end + +@testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed + val = rand(1, 5, 3) + chain = Chains(val, ["a", "b", "c", "d", "e"]) + + for autocov_method in (AutocovMethod(), FFTAutocovMethod(), BDAAutocovMethod()) + # analyze chain + @test_throws ArgumentError ess(chain; autocov_method = autocov_method) + @test_throws ArgumentError ess_rhat(chain; autocov_method = autocov_method) + end + @test all(isnan, rhat(chain)[:, 2]) +end diff --git a/test/ess_tests.jl b/test/ess_tests.jl deleted file mode 100644 index 82bf2e3a..00000000 --- a/test/ess_tests.jl +++ /dev/null @@ -1,54 +0,0 @@ -using MCMCChains -using FFTW - -using Random -using Statistics -using Test - -@testset "ESS per second" begin - c1 = Chains(randn(100,5, 1), info = (start_time=time(), stop_time = time()+1)) - c2 = Chains(randn(100,5, 1), info = (start_time=time()+1, stop_time = time()+2)) - c = chainscat(c1, c2) - - wall = MCMCChains.wall_duration(c) - compute = MCMCChains.compute_duration(c) - - @test round(wall, digits=1) ≈ round(c2.info.stop_time - c1.info.start_time, digits=1) - @test compute ≈ (MCMCChains.compute_duration(c1) + MCMCChains.compute_duration(c2)) - - s = ess_rhat(c) - @test length(s[:,:ess_per_sec]) == 5 - @test all(map(!ismissing, s[:,:ess_per_sec])) -end - -@testset "ESS and R̂ (chains)" begin - x = rand(10_000, 40, 10) - chain = Chains(x) - - for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod()) - # analyze chain - ess_df = ess_rhat(chain; method = method) - - # analyze array - ess_array, rhat_array = ess_rhat(permutedims(x, (1, 3, 2)); method = method) - - @test ess_df[:,2] == ess_array - @test ess_df[:,3] == rhat_array - end -end - -@testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed - val = rand(1, 5, 3) - 
chain = Chains(val, ["a", "b", "c", "d", "e"]) - - for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod()) - # analyze chain - ess_df = ess_rhat(chain; method = method) - - # analyze array - ess_array, rhat_array = ess_rhat(permutedims(val, (1, 3, 2)); method = method) - - @test ismissing(ess_df[:,2][1]) # since min(maxlag, niter - 1) = 0 - @test ismissing(ess_df[:,3][1]) - end -end diff --git a/test/mcse_tests.jl b/test/mcse_tests.jl new file mode 100644 index 00000000..fdbefbb8 --- /dev/null +++ b/test/mcse_tests.jl @@ -0,0 +1,34 @@ +using MCMCChains + +using Random +using Statistics +using Test + +mymean(x) = mean(x) + +@testset "mcse" begin + x = rand(10_000, 40, 10) + chain = Chains(x) + + for kind in (mean, std, mymean) + if kind !== mymean + for autocov_method in (AutocovMethod(), BDAAutocovMethod()) + # analyze chain + mcse_df = mcse(chain; autocov_method = autocov_method, kind = kind) + + # analyze array + mcse_array = mcse( + PermutedDimsArray(x, (1, 3, 2)); autocov_method = autocov_method, kind = kind, + ) + @test mcse_df[:,2] == mcse_array + end + else + # analyze chain + mcse_df = mcse(chain; kind = kind) + + # analyze array + mcse_array = mcse(PermutedDimsArray(x, (1, 3, 2)); kind = kind) + @test mcse_df[:,2] == mcse_array + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 9d6ea571..dfa37395 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,9 +8,17 @@ using Random Random.seed!(0) @testset "MCMCChains" begin - # run tests for effective sample size - println("ESS") - @time include("ess_tests.jl") + # run tests for effective sample size and R-hat + println("ESS/R-hat") + @time include("ess_rhat_tests.jl") + + # run tests for mcse + println("MCSE") + @time include("mcse_tests.jl") + + # run tests for tables interfaces + println("Tables interfaces") + @time include("tables_tests.jl") # run plotting tests println("Plotting") @@ -44,10 +52,6 @@ Random.seed!(0) println("Array") @time include("arrayconstructor_tests.jl") - # 
run tests for tables interfaces - println("Tables interfaces") - @time include("tables_tests.jl") - # run tests for dataframe summary println("Summary") @time include("summarize_tests.jl") diff --git a/test/summarize_tests.jl b/test/summarize_tests.jl index 6827cb18..741fccd3 100644 --- a/test/summarize_tests.jl +++ b/test/summarize_tests.jl @@ -18,16 +18,16 @@ using Statistics: std show(stdout, "text/plain", parm_df) @test 0.48 < parm_df[:a, :mean][1] < 0.52 - @test names(parm_df) == [:parameters, :mean, :std, :naive_se, :mcse, :ess, :rhat] + @test names(parm_df) == [:parameters, :mean, :std, :mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec] # Indexing tests - @test convert(Array, parm_df[:a, :]) == convert(Array, parm_df[:a]) + @test isequal(convert(Array, parm_df[:a, :]), convert(Array, parm_df[:a])) @test parm_df[:a, :][:,:parameters] == :a @test parm_df[[:a, :b], :][:,:parameters] == [:a, :b] all_sections_df = summarize(chns, sections=[:parameters, :internals]) @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h] - @test size(all_sections_df) == (8, 7) + @test size(all_sections_df) == (8, 8) two_parms_two_funs_df = summarize(chns[[:a, :b]], mean, std) @test two_parms_two_funs_df[:, :parameters] == [:a, :b] diff --git a/test/tables_tests.jl b/test/tables_tests.jl index 8383372d..ba3cb71c 100644 --- a/test/tables_tests.jl +++ b/test/tables_tests.jl @@ -125,7 +125,7 @@ using DataFrames @inferred DataFrame(chn) @test DataFrame(chn) isa DataFrame df = DataFrame(chn) - @test Tables.columntable(df) == Tables.columntable(chn) + @test isequal(Tables.columntable(df), Tables.columntable(chn)) end end @@ -144,7 +144,7 @@ using DataFrames @test Tables.columns(cdf) === cdf @test Tables.columnnames(cdf) == keys(cdf.nt) for (k, v) in pairs(cdf.nt) - @test Tables.getcolumn(cdf, k) == v + @test isequal(Tables.getcolumn(cdf, k), v) end @test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1]) @test Tables.getcolumn(cdf, 2) == 
Tables.getcolumn(cdf, keys(cdf.nt)[2]) @@ -163,22 +163,25 @@ using DataFrames row = rows[i] @test Tables.columnnames(row) == keys(cdf.nt) for j in length(cdf.nt) - @test Tables.getcolumn(row, j) == cdf.nt[j][i] - @test Tables.getcolumn(row, keys(cdf.nt)[j]) == cdf.nt[j][i] + @test isequal(Tables.getcolumn(row, j), cdf.nt[j][i]) + @test isequal(Tables.getcolumn(row, keys(cdf.nt)[j]), cdf.nt[j][i]) end end end @testset "integration tests" begin @test length(Tables.rowtable(cdf)) == length(cdf.nt[1]) - @test Tables.columntable(cdf) == cdf.nt + @test isequal(Tables.columntable(cdf), cdf.nt) nt = Tables.rowtable(cdf)[1] - @test nt == (; (k => v[1] for (k, v) in pairs(cdf.nt))...) - @test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1] + @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) + @test isequal(nt, collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1]) nt = Tables.rowtable(cdf)[2] - @test nt == (; (k => v[2] for (k, v) in pairs(cdf.nt))...) - @test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2] - @test Tables.matrix(Tables.rowtable(cdf)) == Tables.matrix(Tables.columntable(cdf)) + @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) + @test isequal(nt, collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2]) + @test isequal( + Tables.matrix(Tables.rowtable(cdf)), + Tables.matrix(Tables.columntable(cdf)), + ) end @testset "schema" begin @@ -192,16 +195,16 @@ using DataFrames @test IteratorInterfaceExtensions.isiterable(cdf) @test TableTraits.isiterabletable(cdf) nt = collect(Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 1))[1] - @test nt == (; (k => v[1] for (k, v) in pairs(cdf.nt))...) + @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) nt = collect(Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 2))[2] - @test nt == (; (k => v[2] for (k, v) in pairs(cdf.nt))...) 
+ @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) end @testset "DataFrames.DataFrame constructor" begin @inferred DataFrame(cdf) df = DataFrame(cdf) @test df isa DataFrame - @test Tables.columntable(df) == cdf.nt + @test isequal(Tables.columntable(df), cdf.nt) end end end