diff --git a/Project.toml b/Project.toml index a5bbf031..53f8b251 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +[weakdeps] +Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" + +[extensions] +MCMCChainsMakieExt = "Makie" + [compat] AbstractMCMC = "0.4, 0.5, 1.0, 2.0, 3.0, 4, 5" AxisArrays = "0.4.4" diff --git a/ext/MCMCChainsMakieExt.jl b/ext/MCMCChainsMakieExt.jl new file mode 100644 index 00000000..fe7a0d24 --- /dev/null +++ b/ext/MCMCChainsMakieExt.jl @@ -0,0 +1,116 @@ +module MCMCChainsMakieExt + +import MCMCChains +import Makie + + +function MCMCChains.trace(chns::T; figure=(;), colormap=Makie.to_colormap(:tol_vibrant)) where {T<:MCMCChains.Chains} + params = MCMCChains.names(chns, :parameters) + + n_chains = length(MCMCChains.chains(chns)) + n_samples = length(chns) + n_params = length(params) + + + colormap = if colormap isa Symbol + Makie.to_colormap(colormap) + else + colormap + end + @show length(colormap) + colorindex(i) = + mod(i - 1, length(colormap)) + 1 + + # set size if not provided + figure = let + width = 600 + height = max(400, 80 * n_params) + nt = (size=(width, height),) + merge(nt, figure) + end + + fig = Makie.Figure(; figure...) + + for (i, param) in enumerate(params) + ax = Makie.Axis(fig[i+1, 1]; ylabel=string(param)) + for chain in 1:n_chains + values = chns[:, param, chain] + Makie.lines!( + ax, + 1:n_samples, + values; + label=string(chain), + color=(colormap[colorindex(chain)], 0.7), + linewidth=0.7 + ) + end + + Makie.hideydecorations!(ax; label=false) + if i < n_params + Makie.hidexdecorations!(ax; grid=false) + else + ax.xlabel = "Iteration" + end + end + + for (i, param) in enumerate(params) + ax = Makie.Axis(fig[i+1, 2]; ylabel=string(param)) + for chain in 1:n_chains + values = chns[:, param, chain] + Makie.density!( + ax, + values; + label=string(chain), + color=(colormap[colorindex(chain)], 0.1), + strokewidth=1, + strokecolor=(colormap[colorindex(chain)], 0.7) + ) + end + + Makie.hideydecorations!(ax) + if i < n_params + Makie.hidexdecorations!(ax; grid=false) + else + ax.xlabel = "Parameter estimate" + end + end + + axes = [only(Makie.contents(fig[i+1, 2])) for i in 1:n_params] + Makie.linkxaxes!(axes...) + + Makie.Legend(fig[1, 1:2], first(axes), "Chain", orientation=:horizontal, titlehalign=:left, halign=:left, titleposition=:left) + + Makie.rowgap!(fig.layout, 10) + Makie.colgap!(fig.layout, 10) + + return fig +end + +# https://docs.makie.org/stable/explanations/specapi#convert_arguments-for-GridLayoutSpec +import Makie.SpecApi as S + +# Our custom type we want to write a conversion method for +struct PlotGrid + nplots::Tuple{Int,Int} +end + +# If we want to use the `color` attribute in the conversion, we have to +# mark it via `used_attributes` +Makie.used_attributes(::T) where {T<:MCMCChains.Chains} = (:linewidth, :alpha) + +# The conversion method creates a grid of `Axis` objects with `Lines` plot inside +# We restrict to Plot{plot}, so that only `plot(PlotGrid(...))` works, but not e.g. `scatter(PlotGrid(...))`. +function Makie.convert_arguments(::Type{Makie.Plot{Makie.plot}}, chn::T; linewidth=0.7, alpha=0.6) where {T<:MCMCChains.Chains} + n_iterations, n_params, n_chains = size(chn) + axes_left = [ + S.Axis(plots=[S.Lines(chn.value[:, p, i]; linewidth, alpha, label=string(i)) for i in 1:n_chains]) + for p in 1:n_params + ] + axes_right = [ + S.Axis(plots=[S.Density(chn.value[:, p, i]; alpha, label=string(i)) for i in 1:n_chains]) + for p in 1:n_params + ] + return S.GridLayout(hcat(axes_left, axes_right)) +end + +end \ No newline at end of file diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index dc7b8784..85e43c43 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -47,6 +47,7 @@ export rafterydiag export rstar export hpd +export trace """ Chains @@ -85,4 +86,20 @@ include("plot.jl") include("tables.jl") include("rstar.jl") + +### trace function via Extensions still needs to define a base function +function trace end + +function __init__() + + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs + if exc.f in [trace] + if isempty(methods(exc.f)) + print(io, "\n$(exc.f) has no methods, yet. Makie has to be loaded for the plotting extension to be activated. Run `using Makie`, `using CairoMakie`, `using GLMakie` or any other package that also loads Makie.") + end + end + end +end + + end # module