Skip to content

Commit

Permalink
Merge pull request #73 from TuringLang/ml/models
Browse files Browse the repository at this point in the history
  • Loading branch information
mileslucas authored Nov 20, 2021
2 parents 230452d + 0d24f7d commit ef42bd4
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 6 deletions.
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ makedocs(
"Home" => "index.md",
"Examples" => [
"Gaussian Shells" => "examples/shells.md",
"Correlated Gaussian" => "examples/correlated.md"
"Correlated Gaussian" => "examples/correlated.md",
"Eggbox" => "examples/eggbox.md",
],
"API/Reference" => "api.md"
],
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ Proposals.RSlice
Models
Models.GaussianShells
Models.CorrelatedGaussian
Models.Eggbox
```
86 changes: 86 additions & 0 deletions docs/src/examples/eggbox.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Eggbox

This example will explore the classic eggbox function using [`Models.Eggbox`](@ref).

## Setup

For this example, you'll need to add the following packages
```julia
julia>]add Distributions MCMCChains Measurements NestedSamplers StatsBase StatsPlots
```

```@setup eggbox
using AbstractMCMC
using Random
AbstractMCMC.setprogress!(false)
Random.seed!(8452)
```

## Define model

```@example eggbox
using NestedSamplers
model, logz = Models.Eggbox()
nothing; # hide
```

let's take a look at a couple of parameters to see what the log-likelihood surface looks like

```@example eggbox
using StatsPlots
x = range(0, 1, length=1000)
y = range(0, 1, length=1000)
logf = [model.loglike([xi, yi]) for yi in y, xi in x]
heatmap(
x, y, logf,
xlims=extrema(x),
ylims=extrema(y),
xlabel="x",
ylabel="y",
)
```

## Sample

```@example eggbox
using MCMCChains
using StatsBase
# using multi-ellipsoid for bounds
# using default rejection sampler for proposals
sampler = Nested(2, 500)
chain, state = sample(model, sampler; dlogz=0.01, param_names=["x", "y"])
# resample chain using statistical weights
chain_resampled = sample(chain, Weights(vec(chain[:weights])), length(chain));
nothing # hide
```

## Results

```@example eggbox
chain_resampled
```

```@example eggbox
marginalkde(chain[:x], chain[:y])
plot!(xlims=(0, 1), ylims=(0, 1), sp=2)
plot!(xlims=(0, 1), sp=1)
plot!(ylims=(0, 1), sp=3)
```

```@example eggbox
density(chain_resampled, xlims=(0, 1))
vline!(0.1:0.2:0.9, c=:black, ls=:dash, sp=1)
vline!(0.1:0.2:0.9, c=:black, ls=:dash, sp=2)
```

```@example eggbox
using Measurements
logz_est = state.logz ± state.logzerr
diff = logz_est - logz
println("logz: $logz")
println("estimate: $logz_est")
println("diff: $diff")
nothing # hide
```
7 changes: 4 additions & 3 deletions docs/src/examples/shells.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,14 @@ let's take a look at a couple of parameters to see what the likelihood surface l
using StatsPlots
x = range(-6, 6, length=1000)
y = range(-6, 6, length=1000)
y = range(-2.5, 2.5, length=1000)
logf = [model.loglike([xi, yi]) for yi in y, xi in x]
heatmap(
x, y, exp.(logf),
aspect_ratio=1,
xlims=extrema(x),
ylims=extrema(y),
xlabel="x",
ylabel="y",
size=(400, 400)
)
```

Expand All @@ -66,6 +64,9 @@ chain_resampled

```@example shells
marginalkde(chain[:x], chain[:y])
plot!(xlims=(-6, 6), ylims=(-2.5, 2.5), sp=2)
plot!(xlims=(-6, 6), sp=1)
plot!(ylims=(-2.5, 2.5), sp=3)
```

```@example shells
Expand Down
1 change: 1 addition & 0 deletions src/models/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ using LogExpFunctions

include("shells.jl")
include("correlated.jl")
include("eggbox.jl")

end # module
31 changes: 31 additions & 0 deletions src/models/eggbox.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
@doc raw"""
Models.Eggbox()
Eggbox/Egg carton likelihood function
```math
z(x, y) = \left[a + \cos\frac{x}{b} \cdot \cos\frac{x}{b} \right]^5
```
# Examples
```jldoctest
julia> model, lnZ = Models.Eggbox();
julia> lnZ
235.88
```
"""
function Eggbox()
tmax = 5π

# uniform prior from 0, 1
prior(X) = X
function loglike(X)
a = cos(tmax * (2 * first(X) - 1) / 2)
b = cos(tmax * (2 * last(X) - 1) / 2)
return (2 + a * b)^5
end

lnZ = 235.88 # where do we get this from??
return NestedModel(loglike, prior), lnZ
end
20 changes: 18 additions & 2 deletions test/models.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const test_bounds = [Bounds.Ellipsoid, Bounds.MultiEllipsoid]
const test_props = [Proposals.Rejection(), Proposals.RWalk(ratio=0.9, walks=50), Proposals.RStagger(ratio=0.9, walks=75), Proposals.Slice(slices=10), Proposals.RSlice()]
const test_props = [Proposals.Rejection(maxiter=Int(1e6)), Proposals.RWalk(ratio=0.9, walks=50), Proposals.RStagger(ratio=0.9, walks=75), Proposals.Slice(slices=10), Proposals.RSlice()]


@testset "$(nameof(bound)), $(nameof(typeof(proposal)))" for bound in test_bounds, proposal in test_props
Expand Down Expand Up @@ -67,7 +67,7 @@ const test_props = [Proposals.Rejection(), Proposals.RWalk(ratio=0.9, walks=50),
chain_res = sample(chain, Weights(vec(chain[:weights])), length(chain))

diff = state.logz - analytic_logz
atol = 5state.logzerr
atol = 6state.logzerr
if diff > atol
@warn "logz estimate is poor" bound proposal error = diff tolerance = atol
end
Expand All @@ -80,5 +80,21 @@ const test_props = [Proposals.Rejection(), Proposals.RWalk(ratio=0.9, walks=50),
@test ymodes[1] -1 atol = σ
@test ymodes[2] 1 atol = σ
end

@testset "Eggbox" begin
model, logz = Models.Eggbox()

sampler = Nested(2, 1000; bounds=bound, proposal=proposal)

chain, state = sample(rng, model, sampler; dlogz=0.1)

@test state.logz logz atol = 5state.logzerr

chain_res = sample(chain, Weights(vec(chain[:weights])), length(chain))
xmodes = sort!(findpeaks(chain_res[:, 1, 1])[1:5])
@test all(isapprox.(xmodes, 0.1:0.2:0.9, atol=0.2))
ymodes = sort!(findpeaks(chain_res[:, 2, 1])[1:5])
@test all(isapprox.(ymodes, 0.1:0.2:0.9, atol=0.2))
end
end

0 comments on commit ef42bd4

Please sign in to comment.