Skip to content

Commit

Permalink
14: Add meta_analytic_samples wrapper (#17)
Browse files Browse the repository at this point in the history
* added wrapper

* get tests passing
  • Loading branch information
danielinteractive authored Mar 1, 2024
1 parent ceb3c13 commit fe52f13
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 49 deletions.
6 changes: 4 additions & 2 deletions src/SafetySignalDetection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@ module SafetySignalDetection

using Turing
using StatsPlots
using DataFrames
using Distributions
using SpecialFunctions
using Statistics
using LinearAlgebra
using ExpectationMaximization

export
meta_analytic,
meta_analysis_model,
meta_analytic_samples,
fit_beta_mixture

include("meta_analytic.jl")
include("meta_analysis.jl")
include("fit_mle.jl")
include("fit_beta_mixture.jl")

Expand Down
85 changes: 85 additions & 0 deletions src/meta_analysis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Meta Analysis Model
This Turing model is used to generate posterior samples of the parameters `a` and `b`.
meta_analysis_model(
y::Vector{Bool},
time::Vector{Float64},
trialindex::Vector{Int64},
prior_a::Distribution,
prior_b::Distribution)
"""
@model function meta_analysis_model(
y::Vector{Bool},
time::Vector{Float64},
trialindex::Vector{Int64},
prior_a::Distribution,
prior_b::Distribution)

n = length(y)
n_trials = maximum(trialindex)

a ~ prior_a
b ~ prior_b

# Truncate the two Beta distribution parameters to be larger than zero to avoid
# initialization problems.
beta_par_min = floatmin(Float64)
first = max(beta_par_min, a * b * n)
second = max(beta_par_min, (1 - a) * b * n)

# Add one more pi parameter here to represent the new trial
# where we want to have the prior for.
pis ~ filldist(Beta(first, second), n_trials + 1)

for i in 1:n
pi = pis[trialindex[i]]
mu = log(-log(1 - pi))
x = mu + log(time[i])
prob = 1 - exp(-exp(x))
y[i] ~ Bernoulli(prob)
end

end;


"""
Meta Analytic Prior Samples Generation
This function wraps the Turing model `meta_analysis_model` and runs it for a data frame `df` with:
- `y`: Bool (did the adverse event occur?)
- `time`: Float64 (time until adverse event or until last treatment or follow up)
- `trialindex`: Int64 (index of trials, starting from 1 and consecutively numbered)
meta_analytic_samples(
df::DataFrame,
prior_a::Distribution,
prior_b::Distribution,
args...
)
Note that arguments for the number of samples per chain and the number of chains have to be passed as well.
It returns an array with the samples from the meta analytic prior (MAP).
"""
function meta_analytic_samples(
df::DataFrame,
prior_a::Distribution,
prior_b::Distribution,
args...
)
chain = sample(
meta_analysis_model(df.y, df.time, df.trialindex, prior_a, prior_b),
NUTS(0.65),
MCMCThreads(),
args...
)
n_trials = maximum(df.trialindex)
predictive_index = n_trials + 1
pi_star_name = "pis[" * string(predictive_index) * "]"
vec(chain[pi_star_name].data)

end
37 changes: 0 additions & 37 deletions src/meta_analytic.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ using Turing
using SafetySignalDetection

include("test_helpers.jl")
include("test_meta_analytic.jl")
include("test_meta_analysis.jl")
include("test_fit_beta_mixture.jl")
include("test_fit_mle.jl")
35 changes: 26 additions & 9 deletions test/test_meta_analytic.jl → test/test_meta_analysis.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "meta_analytic.jl" begin
@testset "Check that meta_analysis_model works as before" begin
rng = StableRNG(123)

n_trials = 5
Expand All @@ -11,7 +11,7 @@

chain = sample(
rng,
meta_analytic(df.y, df.time, df.trialindex, Beta(2, 8), Beta(9, 10)),
meta_analysis_model(df.y, df.time, df.trialindex, Beta(2, 8), Beta(9, 10)),
HMC(0.05, 10),
1000
)
Expand All @@ -20,7 +20,7 @@
check_numerical(chain, [:b], [0.485], rtol=0.001)
end

@testset "Reconcile meta_analytic with rstan on small historical dataset" begin
@testset "Reconcile meta_analysis_model with rstan on small historical dataset" begin
# Create MAP priors on known small historical dataset and compare to rstan

prior_a = Beta(1 / 3, 1 / 3)
Expand All @@ -32,8 +32,8 @@ end
rng = StableRNG(123)
map_small = sample(
rng,
meta_analytic(df_small.y, df_small.time, df_small.trial,
prior_a, prior_b),
meta_analysis_model(df_small.y, df_small.time, df_small.trial,
prior_a, prior_b),
NUTS(0.65),
10_000
)
Expand All @@ -42,7 +42,7 @@ end

end

@testset "Reconcile meta_analytic with rstan on large historical dataset" begin
@testset "Reconcile meta_analysis_model with rstan on large historical dataset" begin
# Create MAP priors on known large historical dataset and compare to rstan

prior_a = Beta(1 / 3, 1 / 3)
Expand All @@ -54,12 +54,29 @@ end
rng = StableRNG(123)
map_large = sample(
rng,
meta_analytic(df_large.y, df_large.time, df_large.trial,
prior_a, prior_b),
meta_analysis_model(df_large.y, df_large.time, df_large.trial,
prior_a, prior_b),
NUTS(0.65),
10_000
)
check_numerical(map_large, [:a], [0.13], rtol = 0.01)
check_numerical(map_large, [:b], [0.55], rtol = 0.01)

end
end

@testset "meta_analytic_samples runs as expected" begin
rng = StableRNG(123)

n_trials = 5
n_patients = 50
df = DataFrame(
y = rand(rng, Bernoulli(0.2), n_trials * n_patients),
time = rand(rng, Exponential(1), n_trials * n_patients),
trialindex = repeat(1:n_trials, n_patients)
)

samples = meta_analytic_samples(df, Beta(2, 8), Beta(9, 10), 100, 1)

@test typeof(samples) == Vector{Float64}
@test length(samples) == 100
end

0 comments on commit fe52f13

Please sign in to comment.