Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert scalar data to 0-dimensional arrays #36

Merged
merged 13 commits into from
Dec 10, 2022
10 changes: 6 additions & 4 deletions .github/workflows/Subpackage_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ jobs:
shell: julia --project=monorepo {0}
run: |
using Pkg;
pkg"dev . ./lib/${{ matrix.subpackage }}"
Pkg.develop([Pkg.PackageSpec(; path="."),
Pkg.PackageSpec(; path="./lib/${{ matrix.subpackage }}")])
- name: Run the tests
continue-on-error: true
run: >
julia --color=yes --project=monorepo -e 'using Pkg; Pkg.test("${{ matrix.subpackage }}", coverage=true)'
run: |
using Pkg
Pkg.test("${{ matrix.subpackage }}"; coverage=true)
shell: julia --color=yes --project=monorepo {0}
- uses: julia-actions/julia-processcoverage@v1
with:
directories: lib/${{ matrix.subpackage }}/src
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InferenceObjects"
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.2.5"
version = "0.2.6"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -10,7 +10,7 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"

[compat]
Compat = "3.46.0, 4.2.0"
DimensionalData = "0.20, 0.21, 0.22, 0.23"
DimensionalData = "0.23.1"
OffsetArrays = "1"
OrderedCollections = "1"
julia = "1.6"
Expand Down
4 changes: 2 additions & 2 deletions lib/InferenceObjectsNetCDF/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InferenceObjectsNetCDF"
uuid = "7cb6d088-77df-42c3-8f05-5ca8d42599d1"
authors = ["Seth Axen <[email protected]>"]
version = "0.2.3"
version = "0.2.4"

[deps]
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Expand All @@ -10,7 +10,7 @@ NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
DimensionalData = "0.20, 0.21, 0.22, 0.23"
DimensionalData = "0.23.1"
InferenceObjects = "0.2"
NCDatasets = "0.12"
Reexport = "1"
Expand Down
5 changes: 4 additions & 1 deletion lib/InferenceObjectsNetCDF/src/InferenceObjectsNetCDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ end

_var_to_array(var, load_mode) = var
function _var_to_array(var, load_mode::Val{:eager})
arr = Array(var)
arr = as_array(Array(var))
attr = var.attrib
try
arr_nomissing = NCDatasets.nomissing(arr)
Expand Down Expand Up @@ -187,4 +187,7 @@ function to_netcdf(data, ds::NCDatasets.NCDataset; group::Symbol=:posterior)
return ds
end

as_array(x) = fill(x)
as_array(x::AbstractArray) = x

end
4 changes: 3 additions & 1 deletion src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ end

Convert `NamedTuple` mapping variable names to arrays to a [`Dataset`](@ref).

Any non-array values will be converted to a 0-dimensional array.

# Keywords

- `attrs::AbstractDict{<:AbstractString}`: a collection of metadata to attach to the
Expand Down Expand Up @@ -79,7 +81,7 @@ function namedtuple_to_dataset(
default_dims=DEFAULT_SAMPLE_DIMS,
)
dim_arrays = map(keys(data)) do var_name
var_data = data[var_name]
var_data = as_array(data[var_name])
var_dims = get(dims, var_name, ())
return array_to_dimarray(var_data, var_name; dims=var_dims, coords, default_dims)
end
Expand Down
5 changes: 5 additions & 0 deletions src/from_namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ whose first dimensions correspond to the dimensions of the containers.

- `InferenceData`: The data with groups corresponding to the provided data

!!! note
If a `NamedTuple` is provided for `observed_data`, `constant_data`, or
predictions_constant_data`, any non-array values (e.g. integers) are converted to
0-dimensional arrays.

# Examples

```@example
Expand Down
5 changes: 4 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ ordered with dimensions of the innermost container first and outermost last.
recursive_stack(x) = x
recursive_stack(x::AbstractArray{<:AbstractArray}) = recursive_stack(stack(x))

as_array(x) = fill(x)
as_array(x::AbstractArray) = x

"""
namedtuple_of_arrays(x::NamedTuple) -> NamedTuple
namedtuple_of_arrays(x::AbstractArray{NamedTuple}) -> NamedTuple
Expand All @@ -25,7 +28,7 @@ ntarray = InferenceObjects.namedtuple_of_arrays(data);
```
"""
function namedtuple_of_arrays end
namedtuple_of_arrays(x::NamedTuple) = map(recursive_stack, x)
namedtuple_of_arrays(x::NamedTuple) = map(as_array ∘ recursive_stack, x)
namedtuple_of_arrays(x::AbstractArray) = namedtuple_of_arrays(namedtuple_of_arrays.(x))
function namedtuple_of_arrays(x::AbstractArray{<:NamedTuple{K}}) where {K}
return NamedTuple{K}(recursive_stack(getproperty.(x, k)) for k in K)
Expand Down
7 changes: 7 additions & 0 deletions test/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,12 @@ using InferenceObjects, DimensionalData, Test
@test metadata["inference_library"] == "MyLib"
@test !haskey(metadata, "inference_library_version")
@test metadata["mykey"] == 5

ds2 = namedtuple_to_dataset((x=1, y=randn(10)); default_dims=())
@test ds2 isa Dataset
@test ds2.x isa DimensionalData.DimArray{<:Any,0}
@test DimensionalData.dims(ds2.x) == ()
@test ds2.y isa DimensionalData.DimArray{<:Any,1}
@test DimensionalData.dims(ds2.y) == (Dim{:y_dim_1}(1:10),)
end
end
12 changes: 8 additions & 4 deletions test/from_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,22 @@ using InferenceObjects, OrderedCollections, Test
library = "MyLib"
dims = (; w=[:wx])
coords = (; wx=1:2)
idata1 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library)
idata1 = from_dict(
dict; group => Dict(:w => [1.0, 2.0], :v => 2.5), dims, coords, library
)
test_idata_group_correct(idata1, :posterior, keys(sizes); library, dims, coords)
test_idata_group_correct(
idata1, group, (:w,); library, dims, coords, default_dims=()
idata1, group, (:w, :v); library, dims, coords, default_dims=()
)

# ensure that dims are matched to named tuple keys
# https://github.com/arviz-devs/ArviZ.jl/issues/96
idata2 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library)
idata2 = from_dict(
dict; group => Dict(:w => [1.0, 2.0], :v => 2.5), dims, coords, library
)
test_idata_group_correct(idata2, :posterior, keys(sizes); library, dims, coords)
test_idata_group_correct(
idata2, group, (:w,); library, dims, coords, default_dims=()
idata2, group, (:w, :v); library, dims, coords, default_dims=()
)
end
end
8 changes: 4 additions & 4 deletions test/from_namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ using InferenceObjects, Test
library = "MyLib"
dims = (; w=[:wx])
coords = (; wx=1:2)
idata1 = from_namedtuple(nt; group => (w=[1.0, 2.0],), dims, coords, library)
idata1 = from_namedtuple(nt; group => (w=[1.0, 2.0], v=2.5), dims, coords, library)
test_idata_group_correct(idata1, :posterior, keys(sizes); library, dims, coords)
test_idata_group_correct(
idata1, group, (:w,); library, dims, coords, default_dims=()
idata1, group, (:w, :v); library, dims, coords, default_dims=()
)

# ensure that dims are matched to named tuple keys
# https://github.com/arviz-devs/ArviZ.jl/issues/96
idata2 = from_namedtuple(nt; group => (w=[1.0, 2.0],), dims, coords, library)
idata2 = from_namedtuple(nt; group => (w=[1.0, 2.0], v=2.5), dims, coords, library)
test_idata_group_correct(idata2, :posterior, keys(sizes); library, dims, coords)
test_idata_group_correct(
idata2, group, (:w,); library, dims, coords, default_dims=()
idata2, group, (:w, :v); library, dims, coords, default_dims=()
)
end
end
4 changes: 2 additions & 2 deletions test/test_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ random_dataset(args...) = Dataset(random_dim_stack(args...))

function random_data()
var_names = (:a, :b)
data_names = (:y,)
data_names = (:y, :v)
stats_names = (:diverging, :energy, :n_steps)
stats_eltypes = (diverging=Bool, n_steps=Int)
coords = (
chain=1:4, draw=1:100, shared=["s1", "s2", "s3"], dima=1:4, dimb=2:6, dimy=1:5
)
dims = (a=(:shared, :dima), b=(:shared, :dimb), y=(:shared, :dimy))
dims = (a=(:shared, :dima), b=(:shared, :dimb), y=(:shared, :dimy), v=())
metadata = Dict{String,Any}("inference_library" => "PPL")
posterior = random_dataset(var_names, dims, coords, metadata, (;))
posterior_predictive = random_dataset(data_names, dims, coords, metadata, (;))
Expand Down
14 changes: 13 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,20 @@ module TestSubModule end
@test InferenceObjects.recursive_stack(1:5) == 1:5
end

@testset "as_array" begin
@test InferenceObjects.as_array(3) == fill(3)
@test InferenceObjects.as_array(2.5) == fill(2.5)
@test InferenceObjects.as_array("var") == fill("var")
x = randn(3)
@test InferenceObjects.as_array(x) == x
x = randn(2, 3)
@test InferenceObjects.as_array(x) == x
x = randn(2, 3, 4)
@test InferenceObjects.as_array(x) == x
end

@testset "namedtuple_of_arrays" begin
@test InferenceObjects.namedtuple_of_arrays((x=3, y=4)) === (x=3, y=4)
@test InferenceObjects.namedtuple_of_arrays((x=3, y=4)) == (x=fill(3), y=fill(4))
@test InferenceObjects.namedtuple_of_arrays([(x=3, y=4), (x=5, y=6)]) ==
(x=[3, 5], y=[4, 6])
@test InferenceObjects.namedtuple_of_arrays([
Expand Down