Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor code to work for OU process #34

Draft
wants to merge 1 commit into
base: joshua/non-delayed
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 5 additions & 57 deletions src/hes5_ode_steady_state.jl
Original file line number Diff line number Diff line change
@@ -1,64 +1,12 @@
"""
Function which defines the HES5 ode system
Function which defines the model ode system
"""
function hes_ode!(du, u, p, t)
du[1] = p[5] * hill_function(u[2], p[1], p[2]) - p[3] * u[1]
du[2] = p[6] * u[1] - p[4] * u[2]
function model_ode!(du, u, p, t)
du[1] = -p[1] * u[1]
end

"""
Calculate the Hill function for a given protein molecule number, repression threshold, and hill coefficient.

# Arguments

- `protein::AbstractFloat`

- `P₀::AbstractFloat`

- `h::AbstractFloat`
"""
function hill_function(protein, P₀, h)
(P₀^h) / (protein^h + P₀^h)
end

"""
The partial derivative of the Hill function with respect to the protein molecule number.

# Arguments

- `protein::AbstractFloat`

- `P₀::AbstractFloat`

- `h::AbstractFloat`
"""
function ∂hill∂p(protein, P₀, h)
-(h * P₀^h * protein^(h - 1)) / (protein^h + P₀^h)^2
end

"""
Calculate the steady state of the Hes5 ODE system, for a specific set of parameters.

# Arguments

- `P₀::AbstractFloat`

- `h::AbstractFloat`

- `μₘ::AbstractFloat`

- `μₚ::AbstractFloat`

- `αₘ::AbstractFloat`

- `αₚ::AbstractFloat`

# Returns

- `steady_state_solution::Array{AbstractFloat,1}`: A 2-element array, giving the steady state for the mRNA and protein respectively.
"""
function calculate_steady_state_of_ode(model_params; initial_guess = [40.0, 5000.0])
prob = SteadyStateProblem(hes_ode!, initial_guess, model_params)
function calculate_steady_state_of_ode(model_params; initial_guess = [1.0])
prob = SteadyStateProblem(model_ode!, initial_guess, model_params)
nl_prob = NonlinearProblem(prob)
solve(nl_prob, DynamicSS(Tsit5())).u
end
42 changes: 14 additions & 28 deletions src/kalman_filter_alg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,9 @@ julia> distributions[1, :]
```
"""
function kalman_filter(data, model_params, measurement_variance; ode_solver = Tsit5())
# F in the paper
observation_transform = [0.0 1.0]
observation_transform = [1.0]
# Jacobian of the system
instant_jac = [
-model_params[3] 0.0
model_params[6] -model_params[4]
]
instant_jac = [-model_params[1]]

# initialise state space and distribution predictions
system_state, predicted_observation_distributions =
Expand All @@ -84,11 +80,10 @@ function kalman_filter(data, model_params, measurement_variance; ode_solver = Ts
system_state.next_time = time
system_state = prediction_step!(system_state, model_params, instant_jac; ode_solver)

system_state = update_step!(system_state, observation, measurement_variance, observation_transform)
# Record the predicted mean and variance for our likelihood
predicted_observation_distributions[observation_index + 1, :] .=
distribution_prediction(system_state, observation_transform, measurement_variance)

system_state = update_step!(system_state, observation, measurement_variance, observation_transform)
end
return system_state, predicted_observation_distributions
end
Expand Down Expand Up @@ -120,21 +115,20 @@ function state_space_initialisation(data, params, observation_transform, measure

# construct system state space
mean = steady_state
variance = diagm(mean .* [20.0, 100.0])
variance = mean * 20.0
system_state = SystemState(mean, variance, data[1, 1], data[2, 1])

# inital update step
update_step!(system_state, data[1, 2], measurement_variance, observation_transform)

# initialise distributions
predicted_observation_distributions = zeros(eltype(mean), first(size(data)), 2)
predicted_observation_distributions[1, :] .= distribution_prediction(system_state, observation_transform, measurement_variance)

# inital update step
update_step!(system_state, data[1, 2], measurement_variance, observation_transform)
return system_state, predicted_observation_distributions
end

function state_space_mean_RHS(du, u, p, t)
du[1] = -p[3] * u[1] + p[5] * hill_function(u[2], p[1], p[2])
du[2] = p[6] * u[1] - p[4] * u[2]
du[1] = -p[1] * u[1]
nothing
end

Expand All @@ -150,21 +144,13 @@ function predict_mean!(system_state, model_params; ode_solver)
system_state, mean_solution
end

function calculate_noise_variance(params, mean_solution, t)
[
(params[3] * first(mean_solution(t)))+(params[5] * hill_function(first(mean_solution(t)), params[1], params[2])) 0.0
0.0 (params[6] * first(mean_solution(t)))+(params[4] * last(mean_solution(t)))
]
end

"""
Predict state space variance to the next observation time index.
"""
function predict_variance!(system_state, mean_solution, model_params, instant_jac; ode_solver)
function predict_variance!(system_state, model_params, instant_jac; ode_solver)
# TODO this currently has to be nested since we don't pass current_mean as a parameter, is this possible / faster?
function variance_RHS(dvariance, current_variance, params, t)
variance_of_noise = calculate_noise_variance(params, mean_solution, t)
dvariance .= instant_jac * current_variance + current_variance * instant_jac' + variance_of_noise
function variance_RHS(dvariance, current_variance, p, t)
dvariance .= instant_jac .* current_variance + current_variance .* instant_jac' .+ p[2]
end

tspan = (system_state.current_time, system_state.next_time)
Expand All @@ -190,8 +176,8 @@ Obtain the Kalman filter prediction about a future observation, `rho_{t+Δt}` an

"""
function prediction_step!(system_state, params, instant_jac; ode_solver)
system_state, mean_solution = predict_mean!(system_state, params; ode_solver)
system_state = predict_variance!(system_state, mean_solution, params, instant_jac; ode_solver)
system_state, _ = predict_mean!(system_state, params; ode_solver)
system_state = predict_variance!(system_state, params, instant_jac; ode_solver)
# update current_time
system_state.current_time = system_state.next_time
system_state
Expand All @@ -208,7 +194,7 @@ end

function update_variance!(system_state, observation_transform, helper_inverse)
system_state.variance -=
system_state.variance * observation_transform' * observation_transform * system_state.variance * helper_inverse
system_state.variance .* observation_transform' .* observation_transform .* system_state.variance * helper_inverse
system_state
end

Expand Down
2 changes: 0 additions & 2 deletions src/log_likelihood.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ Returns
- `log_likelihood::AbstractFloat`.
"""
function calculate_log_likelihood(data, params, measurement_variance; ode_solver = Tsit5())
size(data, 2) == 2 || throw(ArgumentError("observation matrix must be N × 2"))
all(params .>= 0.0) || throw(ErrorException("all model parameters must be positive"))
_, distributions = kalman_filter(data, params, measurement_variance; ode_solver)
logpdf(MvNormal(distributions[:, 1], diagm(distributions[:, 2])), data[:, 2])
end