Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds build_paramDict #26

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
SciPyDiffEq = "505e40e9-d84e-434c-8501-7161967c02cb"
Expand Down
143 changes: 74 additions & 69 deletions scripts/dydt_generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,77 @@ using DifferentialEquations
using Plots
using BenchmarkTools
using Coexist
using Parameters

# initial state
nAge=9
nHS=8
nIso=4
nI=4
nTest=4
param=(nAge=nAge, nHS=nHS, nIso=nIso, nTest=nTest)
ssize=(nAge, nHS, nIso, nTest)
@with_kw mutable struct dydt_Complete
# initial state
nAge=9
nHS=8
nIso=4
nI=4
nTest=4
ssize=(nAge, nHS, nIso, nTest)
param=(nAge=nAge, nHS=nHS, nIso=nIso, nTest=nTest)
trFunc_diseaseProgression=Coexist.trFunc_diseaseProgression()
end

function (f::dydt_Complete)(stateTensor_flattened, p, t)
T = eltype(stateTensor_flattened)
stateTensor = reshape(stateTensor_flattened, (f.nTest, f.nIso, f.nHS, f.nAge))
dydt = zeros(T,size(stateTensor)...)
trTensor_complete = zeros(T,(f.nTest, f.nIso, f.nHS, f.nTest, f.nIso, f.nHS, f.nAge))
trTensor_diseaseProgression = f.trFunc_diseaseProgression(;f.param...)
for k in 1:4
shape = size(trTensor_diseaseProgression[:,k,:,:])
expand_dims = reshape(trTensor_diseaseProgression[:,k,:,:],
(shape[1:end-2]..., 1, shape[end-1:end]...)) # equal to np.exapand_dims
view = Coexist.einsum("ijlml->ijlm", trTensor_complete[:,k,:,:,k,:,:],T)
view .+= expand_dims
trTensor_complete[:,k,:,:,k,:,:] = Coexist._einsum12(trTensor_complete[:,k,:,:,k,:,:], view, T)
end

function dydt_Complete(stateTensor_flattened,p,t)
T = eltype(stateTensor_flattened)
stateTensor = reshape(stateTensor_flattened, (nTest, nIso, nHS, nAge))
dydt = zeros(T,size(stateTensor)...)
trTensor_complete = zeros(T,(nTest, nIso, nHS, nTest, nIso, nHS, nAge))
trTensor_diseaseProgression = Coexist.trFunc_diseaseProgression(;param...)
for k in 1:4
shape = size(trTensor_diseaseProgression[:,k,:,:])
expand_dims = reshape(trTensor_diseaseProgression[:,k,:,:],
(shape[1:end-2]..., 1, shape[end-1:end]...)) # equal to np.exapand_dims
# slice = trTensor_complete[:,k,:,:,k,:,:]
# @einsum view[m,l,j,i] := slice[l,m,l,j,i]
view = Coexist.einsum("ijlml->ijlm", trTensor_complete[:,k,:,:,k,:,:],T)
view .+= expand_dims
trTensor_complete[:,k,:,:,k,:,:] = Coexist._einsum12(trTensor_complete[:,k,:,:,k,:,:], view, T)
# @einsum slice[l,m,l,j,i] = view[m,l,j,i]
# trTensor_complete[:,k,:,:,k,:,:] = slice
end
to_be_modified = Coexist.einsum("ijkljkl->ijkl", trTensor_complete, T)
to_be_modified .-= Coexist.einsum("...jkl->...", trTensor_complete, T)
trTensor_complete = Coexist._einsum11(trTensor_complete, to_be_modified, T)

# @einsum to_be_modified[l,k,j,i] := trTensor_complete[l,k,j,l,k,j,i]
to_be_modified = Coexist.einsum("ijkljkl->ijkl", trTensor_complete, T)
to_be_modified .-= Coexist.einsum("...jkl->...", trTensor_complete, T)
# @einsum trTensor_complete[l,k,j,l,k,j,i] = view[l,k,j,i]
trTensor_complete = Coexist._einsum11(trTensor_complete, to_be_modified, T)
dydt = Coexist.einsum("ijkl,ijklmnp->imnp", stateTensor, trTensor_complete, T)
return vec(dydt)
end


dydt = Coexist.einsum("ijkl,ijklmnp->imnp", stateTensor, trTensor_complete, T)
return vec(dydt)
@with_kw mutable struct solveSystem
timeSpan=(0.0, 80.0)
p=nothing
dydt_Complete=dydt_Complete()
end

state = 50*ones(9*8*4*4)
dstate = similar(state)
function (f::solveSystem)(state, timeSpan=(0.0, 80.0), p=nothing)
f.timeSpan=timeSpan
f.p=p
println("solveSystem")
dstate = similar(state)

using ModelingToolkit
@variables t u[1:1152](t)
@derivatives D'~t
# @variables t u[1:1152](t)
# @derivatives D'~t
# dydt_symbolic = f.dydt_Complete(u,0,nothing)
# dydt_symbolic = simplify.(dydt_symbolic)
# sys = ODESystem(D.(u) .~ dydt_symbolic,t,u,[])
# computed_fn = generate_function(sys)[2]
#
# fn = eval(computed_fn)

dydt_symbolic = dydt_Complete(u,0,nothing)
dydt_symbolic = simplify.(dydt_symbolic)
sys = ODESystem(D.(u) .~ dydt_symbolic,t,u,[])
computed_f = generate_function(sys)[2]
## open(joinpath(@__DIR__,"..","src","generated_dydt.jl"), "w") do io
## write(io, "const dydt = $computed_f")
## end

open(joinpath(@__DIR__,"..","src","generated_dydt.jl"), "w") do io
write(io, "const dydt = $computed_f")
# prob = ODEProblem(fn,state,(0.0,80.0),p=nothing)
prob = ODEProblem(f.dydt_Complete,state,(0.0,80.0),p=nothing)
sol = solve(prob,Tsit5(),reltol=1e-3,abstol=1e-3)
sol = convert(Array, sol)
reshape(sol, (4,4,8,9, length(sol) ÷ (4*4*8*9)))
end

f = eval(computed_f)
state = 50*ones(9*8*4*4)

function solveSystem_unoptimized(
state,
Expand All @@ -72,28 +86,19 @@ function solveSystem_unoptimized(
return reshape(sol, (4,4,8,9, length(sol) ÷ (4*4*8*9)))
end

function solveSystem(
state,
timeSpan = (0.0, 80.0),
p=nothing
)
prob = ODEProblem(f,state,(0.0,80.0),p=nothing)
sol = solve(prob,Tsit5(),reltol=1e-3,abstol=1e-3)
sol = convert(Array, sol)
reshape(sol, (4,4,8,9, length(sol) ÷ (4*4*8*9)))
end

using BenchmarkTools
dydt_benchmark = @benchmark dydt_Complete(state,0,nothing)
f_benchmark = @benchmark f(dstate,state,0,nothing)
solveSystem_unoptimized_benchmark = @benchmark solveSystem_unoptimized(state)
solveSystem_benchmark = @benchmark solveSystem(state)
soln = solveSystem()(state)

println("BENCHMARK OF f")
display(dydt_benchmark)
display(f_benchmark)
println()
println("BENCHMARK OF solveSystem")
display(solveSystem_unoptimized_benchmark)
display(solveSystem_benchmark)
println()
# using BenchmarkTools
# dydt_benchmark = @benchmark dydt_Complete(state,0,nothing)
# f_benchmark = @benchmark f(dstate,state,0,nothing)
# solveSystem_unoptimized_benchmark = @benchmark solveSystem_unoptimized(state)
# solveSystem_benchmark = @benchmark solveSystem(state)
#
# println("BENCHMARK OF f")
# display(dydt_benchmark)
# display(f_benchmark)
# println()
# println("BENCHMARK OF solveSystem")
# display(solveSystem_unoptimized_benchmark)
# display(solveSystem_benchmark)
# println()
3 changes: 3 additions & 0 deletions src/Coexist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ using LinearAlgebra
using Dates
import StatsFuns: logistic, gammapdf
using PyCall
using Parameters

abstract type CType end

const DATA_DIR = joinpath(dirname(@__FILE__), "..", "data")

Expand Down
Loading