from_cmdstan
#335
Comments
Integration with Stan is implemented directly in StanSample.jl. For an example, check out https://julia.arviz.org/ArviZ/stable/quickstart/#Plotting-with-Stan.jl-outputs .
You're welcome! |
Perfect, that looks like what I need. I quite like working with cmdstan directly on the command line so I'll probably try to extract the csv reading and conversion from StanSample.jl, which I see just uses |
I'll add that it is very common to run Stan from the CLI and then read the samples later for analysis in a programming language (e.g. R/Python/Julia). |
This makes sense. We used to wrap Python ArviZ's |
Thanks, I'll take a closer look at I'd like to do a PR but it might be a while since the semester has just started, and I might have to come back to you for some guidance, if that is ok. |
Indeed it does. DataFrames.jl in particular. I think it would make the most sense to define But also, StanIO itself can probably make most of its dependencies weak dependencies using extensions as well. I'll open an issue there. Either way, I think the extension route makes the most sense here.
Of course! Happy to help. |
Hi @sethaxen, after a bit of tinkering, I now have a first basic implementation of this. I ended up not relying on I did take up a dependency on the My two main questions right now would be:
I didn't start a PR yet, because I wanted to ask your opinion on where to put this first. Here's what I got so far:

using DelimitedFiles
using InferenceObjects: from_namedtuple

"""
    readheader(file)

Read the header with variable names from a Stan csv file.
"""
function readheader(file)
    for line in eachline(file)
        startswith(line, "#") && continue
        return string.(split(line, ","))
    end
end

"""
    readfiles(files)

Read one or more Stan output csv files into an ndraws x nvars x nchains
array. Return a tuple containing the array and the variable names.
Assumes that all files have the same schema.
"""
function readfiles(files)
    header = readheader(first(files))
    values = stack(files) do file
        arr, _ = readdlm(file, ','; comments=true, header=true)
        arr
    end
    return values, header
end
readfiles(file::AbstractString) = readfiles([file])

"""
    vardims_from_names(names)

Parse the dimensions from the variable names as contained in a Stan
output csv file. Return a dict with the dimensions for each variable.
A list of names such as `["a.1.1", "a.1.2", "a.2.1", "a.2.2"]` would
yield a `Dict("a" => (2, 2))`.
Assumes that the names are ordered by dims.
"""
function vardims_from_names(names)
    res = Dict{String,NTuple}()
    # this relies on dims being sorted in the csv: the last name of each
    # variable carries its largest index along every dimension
    for name in reverse(names)
        var, dims... = split(name, ".")
        var = string(var)
        haskey(res, var) && continue
        res[var] = tuple(parse.(Int, dims)...)
    end
    return res
end

"""
    varindexes(names)

Map variables to their position range in the list of names as
contained in the csv header.
"""
function varindexes(names)
    vars = first.(split.(names, "."))
    vind = map(unique(vars)) do var
        from = findfirst(==(var), vars)
        to = findlast(==(var), vars)
        string(var) => from:to
    end
    return Dict(vind)
end

"""
    to_namedtuple(arr, names)

Split and reshape the draws in `arr` according to the variable dimensions
parsed from `names` and return a named tuple.
"""
function to_namedtuple(arr, names)
    draws, nvars, chains = size(arr)
    vind = varindexes(names)
    vdim = vardims_from_names(names)
    vars = string.(first.(split.(names, ".")))
    res = map(unique(vars)) do var
        # arr is draws x nvars x chains, so move the chain axis to position 2
        # before reshaping; a bare reshape would mix chains and components
        a = permutedims(arr[:, vind[var], :], (1, 3, 2))
        Symbol(var) => reshape(a, (draws, chains, vdim[var]...))
    end
    return (; res...)
end
## This is copied from the StanSample.jl extension
const SAMPLE_STATS_KEYMAP = (
    n_leapfrog__=:n_steps,
    treedepth__=:tree_depth,
    energy__=:energy,
    lp__=:lp,
    stepsize__=:step_size,
    divergent__=:diverging,
    accept_stat__=:acceptance_rate,
)

function rekey(nt::NamedTuple, keymap)
    new_keys = map(k -> get(keymap, k, k), keys(nt))
    return NamedTuple{new_keys}(values(nt))
end

function split_post_stats(nt, keymap)
    stats = filter(in(values(keymap)), keys(nt))
    post = filter(!in(values(keymap)), keys(nt))
    return NamedTuple{post}(nt), NamedTuple{stats}(nt)
end

is_file(arg::AbstractString) = endswith(arg, ".csv")
is_file(arg::AbstractVector{<:AbstractString}) = all(endswith(a, ".csv") for a in arg)
# Anything else (e.g. a named tuple of draws) is not a file and is passed through
is_file(arg) = false

"""
    from_cmdstan(posterior::Union{<:AbstractString,Vector{<:AbstractString}}; kwargs...)

Create an `InferenceData` from CmdStan csv files. `kwargs` can be filenames indicating CmdStan output,
such as prior draws or generated quantities, or named tuples. If they are files, the contained draws are
reshaped into (draws x chains x vardims...) arrays and are passed to `InferenceObjects.from_namedtuple`.
"""
function from_cmdstan(
    posterior::Union{<:AbstractString,Vector{<:AbstractString}};
    prior=nothing,
    sample_stats_prior=nothing,
    kwargs...,
)
    post, sample_stats = let
        nt = to_namedtuple(readfiles(posterior)...)
        nt = rekey(nt, SAMPLE_STATS_KEYMAP)
        split_post_stats(nt, SAMPLE_STATS_KEYMAP)
    end
    if !isnothing(prior)
        nt = to_namedtuple(readfiles(prior)...)
        nt = rekey(nt, SAMPLE_STATS_KEYMAP)
        prior, sample_stats_prior = split_post_stats(nt, SAMPLE_STATS_KEYMAP)
    end
    kwargs = map(NamedTuple(kwargs)) do arg
        is_file(arg) || return arg
        to_namedtuple(readfiles(arg)...)
    end
    return from_namedtuple(post; sample_stats, prior, sample_stats_prior, kwargs...)
end
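
One thing worth checking with the layout described above: `readfiles` returns a draws x nvars x nchains array, while the named tuple wants draws x chains x vardims. Here is a minimal sketch with made-up toy data showing why the chain axis has to be moved before reshaping. The names `colnames`, `theta`, `wrong` and `right` are purely illustrative and not part of any package, and only Base functions are used.

```julia
# Hypothetical toy header: one stat, one scalar, and a 2-vector `theta`.
colnames = ["lp__", "mu", "theta.1", "theta.2"]
ndraws, nchains = 2, 2
nvars = length(colnames)

# arr[d, v, c] = 100c + 10d + v, so every entry encodes where it came from.
arr = [100c + 10d + v for d in 1:ndraws, v in 1:nvars, c in 1:nchains]

# Columns 3:4 hold `theta`; the slice is still (draws, components, chains).
theta = arr[:, 3:4, :]

# A bare reshape only reinterprets memory: axis 2 is still the component axis.
wrong = reshape(theta, (ndraws, nchains, 2))

# Moving the chain axis to position 2 first keeps every value in its place.
right = reshape(permutedims(theta, (1, 3, 2)), (ndraws, nchains, 2))

# Draw 1, chain 2, theta[1] is column 3 of chain 2: 100*2 + 10*1 + 3 = 213
println(right[1, 2, 1])
println(wrong[1, 2, 1])  # 114, which is actually draw 1, chain 1, theta[2]
```

With a single chain the two results coincide, so the difference only shows up once several csv files are stacked.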
Is there something similar in ArviZ.jl to `from_cmdstan()` in the Python library? If not, is there a specific way to plug into `InferenceData` that you would recommend? Thanks for all your work in the ArviZ Julia ecosystem, it seems like a great project!