From 2b3796d2b864f530b3581b0dbb9b336ad4b0244c Mon Sep 17 00:00:00 2001 From: burtonjosh Date: Sun, 25 Feb 2024 22:39:13 +0000 Subject: [PATCH] WIP reimplement for the OU process. This should give an idea about how to refactor the code to accomodate for generic models --- src/hes5_ode_steady_state.jl | 62 +++--------------------------------- src/kalman_filter_alg.jl | 42 ++++++++---------------- src/log_likelihood.jl | 2 -- 3 files changed, 19 insertions(+), 87 deletions(-) diff --git a/src/hes5_ode_steady_state.jl b/src/hes5_ode_steady_state.jl index f51ec9f..c5bc35a 100644 --- a/src/hes5_ode_steady_state.jl +++ b/src/hes5_ode_steady_state.jl @@ -1,64 +1,12 @@ """ -Function which defines the HES5 ode system +Function which defines the model ode system """ -function hes_ode!(du, u, p, t) - du[1] = p[5] * hill_function(u[2], p[1], p[2]) - p[3] * u[1] - du[2] = p[6] * u[1] - p[4] * u[2] +function model_ode!(du, u, p, t) + du[1] = -p[1] * u[1] end -""" -Calculate the Hill function for a given protein molecule number, repression threshold, and hill coefficient. - -# Arguments - -- `protein::AbstractFloat` - -- `P₀::AbstractFloat` - -- `h::AbstractFloat` -""" -function hill_function(protein, P₀, h) - (P₀^h) / (protein^h + P₀^h) -end - -""" -The partial derivative of the Hill function with respect to the protein molecule number. - -# Arguments - -- `protein::AbstractFloat` - -- `P₀::AbstractFloat` - -- `h::AbstractFloat` -""" -function ∂hill∂p(protein, P₀, h) - -(h * P₀^h * protein^(h - 1)) / (protein^h + P₀^h)^2 -end - -""" -Calculate the steady state of the Hes5 ODE system, for a specific set of parameters. - -# Arguments - -- `P₀::AbstractFloat` - -- `h::AbstractFloat` - -- `μₘ::AbstractFloat` - -- `μₚ::AbstractFloat` - -- `αₘ::AbstractFloat` - -- `αₚ::AbstractFloat` - -# Returns - -- `steady_state_solution::Array{AbstractFloat,1}`: A 2-element array, giving the steady state for the mRNA and protein respectively. -""" -function calculate_steady_state_of_ode(model_params; initial_guess = [40.0, 5000.0]) - prob = SteadyStateProblem(hes_ode!, initial_guess, model_params) +function calculate_steady_state_of_ode(model_params; initial_guess = [1.0]) + prob = SteadyStateProblem(model_ode!, initial_guess, model_params) nl_prob = NonlinearProblem(prob) solve(nl_prob, DynamicSS(Tsit5())).u end diff --git a/src/kalman_filter_alg.jl b/src/kalman_filter_alg.jl index 5b9b23f..230ecdf 100644 --- a/src/kalman_filter_alg.jl +++ b/src/kalman_filter_alg.jl @@ -66,13 +66,9 @@ julia> distributions[1, :] ``` """ function kalman_filter(data, model_params, measurement_variance; ode_solver = Tsit5()) - # F in the paper - observation_transform = [0.0 1.0] + observation_transform = [1.0] # Jacobian of the system - instant_jac = [ - -model_params[3] 0.0 - model_params[6] -model_params[4] - ] + instant_jac = [-model_params[1]] # initialise state space and distribution predictions system_state, predicted_observation_distributions = @@ -84,11 +80,10 @@ function kalman_filter(data, model_params, measurement_variance; ode_solver = Ts system_state.next_time = time system_state = prediction_step!(system_state, model_params, instant_jac; ode_solver) + system_state = update_step!(system_state, observation, measurement_variance, observation_transform) # Record the predicted mean and variance for our likelihood predicted_observation_distributions[observation_index + 1, :] .= distribution_prediction(system_state, observation_transform, measurement_variance) - - system_state = update_step!(system_state, observation, measurement_variance, observation_transform) end return system_state, predicted_observation_distributions end @@ -120,21 +115,20 @@ function state_space_initialisation(data, params, observation_transform, measure # construct system state space mean = steady_state - variance = diagm(mean .* [20.0, 100.0]) + variance = mean * 20.0 system_state = SystemState(mean, variance, data[1, 1], data[2, 1]) + # inital update step + update_step!(system_state, data[1, 2], measurement_variance, observation_transform) + # initialise distributions predicted_observation_distributions = zeros(eltype(mean), first(size(data)), 2) predicted_observation_distributions[1, :] .= distribution_prediction(system_state, observation_transform, measurement_variance) - - # inital update step - update_step!(system_state, data[1, 2], measurement_variance, observation_transform) return system_state, predicted_observation_distributions end function state_space_mean_RHS(du, u, p, t) - du[1] = -p[3] * u[1] + p[5] * hill_function(u[2], p[1], p[2]) - du[2] = p[6] * u[1] - p[4] * u[2] + du[1] = -p[1] * u[1] nothing end @@ -150,21 +144,13 @@ function predict_mean!(system_state, model_params; ode_solver) system_state, mean_solution end -function calculate_noise_variance(params, mean_solution, t) - [ - (params[3] * first(mean_solution(t)))+(params[5] * hill_function(first(mean_solution(t)), params[1], params[2])) 0.0 - 0.0 (params[6] * first(mean_solution(t)))+(params[4] * last(mean_solution(t))) - ] -end - """ Predict state space variance to the next observation time index. """ -function predict_variance!(system_state, mean_solution, model_params, instant_jac; ode_solver) +function predict_variance!(system_state, model_params, instant_jac; ode_solver) # TODO this currently has to be nested since we don't pass current_mean as a parameter, is this possible / faster? - function variance_RHS(dvariance, current_variance, params, t) - variance_of_noise = calculate_noise_variance(params, mean_solution, t) - dvariance .= instant_jac * current_variance + current_variance * instant_jac' + variance_of_noise + function variance_RHS(dvariance, current_variance, p, t) + dvariance .= instant_jac .* current_variance + current_variance .* instant_jac' .+ p[2] end tspan = (system_state.current_time, system_state.next_time) @@ -190,8 +176,8 @@ Obtain the Kalman filter prediction about a future observation, `rho_{t+Δt}` an """ function prediction_step!(system_state, params, instant_jac; ode_solver) - system_state, mean_solution = predict_mean!(system_state, params; ode_solver) - system_state = predict_variance!(system_state, mean_solution, params, instant_jac; ode_solver) + system_state, _ = predict_mean!(system_state, params; ode_solver) + system_state = predict_variance!(system_state, params, instant_jac; ode_solver) # update current_time system_state.current_time = system_state.next_time system_state @@ -208,7 +194,7 @@ end function update_variance!(system_state, observation_transform, helper_inverse) system_state.variance -= - system_state.variance * observation_transform' * observation_transform * system_state.variance * helper_inverse + system_state.variance .* observation_transform' .* observation_transform .* system_state.variance * helper_inverse system_state end diff --git a/src/log_likelihood.jl b/src/log_likelihood.jl index 01906fd..fc02af4 100644 --- a/src/log_likelihood.jl +++ b/src/log_likelihood.jl @@ -21,8 +21,6 @@ Returns - `log_likelihood::AbstractFloat`. """ function calculate_log_likelihood(data, params, measurement_variance; ode_solver = Tsit5()) - size(data, 2) == 2 || throw(ArgumentError("observation matrix must be N × 2")) - all(params .>= 0.0) || throw(ErrorException("all model parameters must be positive")) _, distributions = kalman_filter(data, params, measurement_variance; ode_solver) logpdf(MvNormal(distributions[:, 1], diagm(distributions[:, 2])), data[:, 2]) end