diff --git a/Project.toml b/Project.toml index 5da4fc4..c2f5171 100644 --- a/Project.toml +++ b/Project.toml @@ -14,8 +14,14 @@ Fleck = "5bb9b785-358c-4fee-af0f-b94a146244a8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" [compat] AlgebraicPetri = "0.9" +AlgebraicRewriting = "0.3" +Reexport = "1" +SpecialFunctions = "2" +DataStructures = "0.18" +Distributions = "0.25" Catlab = "^0.16" julia = "1.6" diff --git a/docs/Project.toml b/docs/Project.toml index 18eb94c..377a12c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ AlgebraicPetri = "4f99eebe-17bf-4e98-b6a1-2c4f205a959b" AlgebraicRewriting = "725a01d3-f174-5bbd-84e1-b9417bad95d9" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" diff --git a/src/ABMs.jl b/src/ABMs.jl index fac5eda..15ed6c5 100644 --- a/src/ABMs.jl +++ b/src/ABMs.jl @@ -1,18 +1,20 @@ module ABMs -export ABM +export ABM, ABMRule, run!, DiscreteHazard, ContinuousHazard, FullClosure, + ClosureState, ClosureTime using Distributions, Fleck, Random using DataStructures: DefaultDict +using StructEquality using Catlab, AlgebraicRewriting using AlgebraicPetri: AbstractReactionNet using AlgebraicRewriting.Incremental: connected_acset_components, key_dict using AlgebraicRewriting.Rewrite.Migration: pres_hash +using AlgebraicRewriting.Rewrite.Utils: get_pmap, get_rmap import Catlab: acset_schema, right, is_isomorphic -import ..PetriInterface: run! -import AlgebraicRewriting: get_match +import AlgebraicRewriting: get_match, ruletype import AlgebraicRewriting.Incremental: addition!, deletion! # Possibly upstream @@ -47,7 +49,9 @@ end Something that can produce a ACSetTransformation × clocktime → hazard_rate """ abstract type AbsTimer end -abstract type StateDependentTimer <: AbsTimer end + +abstract type StateDependentTimer <: AbsTimer end + state_dep(t::AbsTimer) = t isa StateDependentTimer """ @@ -57,6 +61,7 @@ clocktime → hazard_rate struct FullClosure <: StateDependentTimer val::Function # ACSetTransformation → clocktime → hazard_rate end + (c::FullClosure)(m::ACSetTransformation, t::Float64) = c.val(m,t) """ @@ -66,6 +71,7 @@ which cannot depend on the match data nor ACSet state. struct ClosureTime <: AbsTimer val::Function # clocktime → hazard_rate end + (c::ClosureTime)(t::Float64) = c.val(t) """ @@ -75,25 +81,29 @@ timer which cannot depend on the absolute clock time. struct ClosureState <: StateDependentTimer val::Function # clocktime → hazard_rate end + (c::ClosureState)(m::ACSetTransformation) = c.val(m) abstract type AbsHazard <: AbsTimer end -struct DiscreteHazard <: AbsHazard +@struct_hash_equal struct DiscreteHazard <: AbsHazard val::Distribution{Univariate, Discrete} end DiscreteHazard(t::Number) = DiscreteHazard(Dirac(t)) -struct ContinuousHazard <: AbsHazard +@struct_hash_equal struct ContinuousHazard <: AbsHazard val::Distribution{Univariate, Continuous} end ContinuousHazard(p::Number) = ContinuousHazard(Exponential(p)) get_hazard(m::ACSetTransformation, t::Float64, h::FullClosure) = h(m,t) + get_hazard(::ACSetTransformation, t::Float64, h::ClosureTime) = h(t) + get_hazard(m::ACSetTransformation, ::Float64, h::ClosureState) = h(m) + get_hazard(::ACSetTransformation, ::Float64, h::AbsHazard) = h.val # Rules @@ -101,14 +111,14 @@ get_hazard(::ACSetTransformation, ::Float64, h::AbsHazard) = h.val abstract type PatternType end """Empty patterns have (one) trivial pattern match""" -struct EmptyP <: PatternType end +@struct_hash_equal struct EmptyP <: PatternType end """ Default case, where pattern matches should be found via (incremental) homomorphism search and represented explicitly, each with own events getting scheduled. """ -struct RegularP <: PatternType end +@struct_hash_equal struct RegularP <: PatternType end # """ # Special case of homsearch where no backtracking is needed. The only nonempty @@ -120,7 +130,7 @@ struct RegularP <: PatternType end # WARNING: this is only viable if the timer associated with the rewrite rule is # symmteric with respect to the discrete parts. # """ -# struct DiscreteP <: PatternType +# @struct_hash_equal struct DiscreteP <: PatternType # parts::Dict{Symbol, Int} # end @@ -136,9 +146,10 @@ colimit leg sends 1, as there is often just one X part in the representable X). WARNING: this is only viable if the timer associated with the rewrite rule is symmteric with respect to the disjoint representables. """ -struct RepresentableP <: PatternType +@struct_hash_equal struct RepresentableP <: PatternType parts::Dict{Symbol, Vector{Int}} -end +end + Base.keys(p::RepresentableP) = keys(p.parts) """ @@ -181,26 +192,32 @@ end """ A stochastic rewrite rule with a dependent hazard rate """ -struct ABMRule +@struct_hash_equal struct ABMRule rule::Rule timer::AbsTimer pattern_type::PatternType ABMRule(r::Rule, t::AbsTimer) = new(r, t, pattern_type(r)) end + getrule(r::ABMRule) = r.rule + pattern_type(r::ABMRule) = r.pattern_type + pattern(r::ABMRule) = pattern(getrule(r)) + right(r::ABMRule) = right(getrule(r)) +ruletype(r::ABMRule) = ruletype(getrule(r)) + abstract type AbsDynamics end """Use a petri net with rates""" -struct PetriDynamics <: AbsDynamics +@struct_hash_equal struct PetriDynamics <: AbsDynamics val::AbstractReactionNet end """Continuous dynamics""" -struct ABMFlow +@struct_hash_equal struct ABMFlow pat::ACSet dyn::AbsDynamics mapping::Vector{Pair{Symbol, Int}} # pair pat's variables w/ dyn quantities @@ -209,22 +226,31 @@ end """ An agent-based model. """ -struct ABM +@struct_hash_equal struct ABM rules::Vector{ABMRule} dyn::Vector{ABMFlow} ABM(rules, dyn=[]) = new(rules, dyn) end + additions(abm::ABM) = right.(abm.rules) """A collection of timers associated at runtime w/ an ABMRule""" -abstract type AbsHomSet end -struct EmptyHomSet <: AbsHomSet end -struct DiscreteHomSet <: AbsHomSet end -struct ExplicitHomSet <: AbsHomSet val::IncHomSet end +abstract type AbsHomSet end + +@struct_hash_equal struct EmptyHomSet <: AbsHomSet end + +@struct_hash_equal struct DiscreteHomSet <: AbsHomSet end + +@struct_hash_equal struct ExplicitHomSet <: AbsHomSet val::IncHomSet end + Base.keys(h::ExplicitHomSet) = keys(h.val) + Base.pairs(h::ExplicitHomSet) = pairs(h.val) + Base.getindex(h::ExplicitHomSet, i) = h.val[i] + deletion!(h::ExplicitHomSet, m) = deletion!(h.val, m) + addition!(h::ExplicitHomSet, k, r, u) = addition!(h.val, k, r, u) """Initialize runtime hom-set given the rule and the initial state""" @@ -236,15 +262,17 @@ function init_homset(rule::ABMRule, state::ACSet, additions::Vector{<:ACSetTrans return DiscreteHomSet() end -const KeyType = Union{Pair{Int,Int}} # connected component homset - Tuple{Int,Vector{Pair{Int,Int}}} # multi-component homset +const Maybe{T} = Union{Nothing, T} +const KeyType = Union{Pair{Int, Int}} # connected comp. homset + Tuple{Int,Vector{Pair{Int,Int}}} # multi-component homset + const default_sampler = FirstToFire{ Union{Pair{Int, Nothing}, # non-explicit homset Pair{Int, KeyType}}, # explicit homset Float64} """ -Data structure for maintaining simulation information while running an ABM +Data @struct_hash_equal structure for maintaining simulation information while running an ABM """ mutable struct RuntimeABM state::ACSet @@ -268,13 +296,17 @@ mutable struct RuntimeABM end state(r::RuntimeABM) = r.state + Base.haskey(rt::RuntimeABM, k) = haskey(rt.sampler.transition_entry, k) +Base.haskey(rt::RuntimeABM, k::Int) = + haskey(rt.sampler.transition_entry, k=>nothing) + """Pick the next random event, advance the clock""" function Fleck.next(rt::RuntimeABM) rt.nevent += 1 (rt.tnow, which) = next(rt.sampler, rt.tnow, rt.rng) - which + return which end @@ -284,22 +316,26 @@ Get match returns a randomly chosen morphism for the aggregate rule TODO incorporate the number of possibilities as a multiplier for the rate """ get_match(::EmptyP, L::ACSet, G::ACSet, ::EmptyHomSet, ::Nothing) = create(G) -function get_match(P::RepresentableP, L::T, G::ACSet, ::DiscreteHomSet, ::Nothing) where T<:ACSet + +function get_match(P::RepresentableP, L::T, G::ACSet, ::DiscreteHomSet, + ::Nothing) where T<:ACSet initial = Dict(map(collect(pairs(P.parts))) do (o, idxs) o => Dict(idx => rand(parts(G, o)) for idx in idxs) end) - homomorphism(L, G; initial) + return homomorphism(L, G; initial) end + get_match(::RegularP, ::ACSet, ::ACSet, hs::ExplicitHomSet, key::KeyType) = hs[key] """ A trajectory of an ABM: each event time and result of `save`. """ -struct Traj +@struct_hash_equal struct Traj init::ACSet events::Vector{Tuple{Float64, Int, Any}} hist::Vector{Span{<:ACSet}} end + Traj(x::ACSet) = Traj(x, Pair{Float64, Any}[], Span{ACSet}[]) function Base.push!(t::Traj, τ::Float64,rule::Int, v::Any, sp::Span{<:ACSet}) @@ -314,64 +350,82 @@ Base.isempty(t::Traj) = isempty(t.events) Base.length(t::Traj) = length(t.events) -const MAXEVENT = 10 +const MAXEVENT = 100 """ -Run an ABM, creating a fresh trajectory. +Run an ABM, creating a fresh runtime + trajectory. """ function run!(abm::ABM, init::T; save=deepcopy, maxevent=MAXEVENT, maxtime=Inf, kw...) where T<:ACSet - run!(abm::ABM, RuntimeABM(abm, init; kw...); save, maxevent) + run!(abm::ABM, RuntimeABM(abm, init; kw...), Traj(init); save, maxevent) end -function run!(abm::ABM, rt::RuntimeABM, output::Union{Traj,Nothing}=nothing; +function run!(abm::ABM, rt::RuntimeABM, output::Traj; save=deepcopy, maxevent=MAXEVENT, maxtime=Inf) - output = isnothing(output) ? Traj(rt.state) : output + # Helper functions that automatically incorporate the runtime `rt` log!(rule::Int, sp::Span) = push!(output, rt.tnow, rule, save(rt.state), sp) - - disable!′(which) = disable!(rt.sampler, which, rt.tnow) - + disable!′(key::Pair) = disable!(rt.sampler, key, rt.tnow) + disable!′(i::Int) = disable!′(i => nothing) function enable!′(m::ACSetTransformation, rule::Int, key=nothing) haz = get_hazard(m, rt.tnow, abm.rules[rule].timer) enable!(rt.sampler, rule => key, haz, rt.tnow, rt.tnow, rt.rng) end + # Main loop while rt.nevent < maxevent && rt.tnow < maxtime - which = next(rt) # get next event, update time - isnothing(which) && return output # end b/c no more events! + # get next event + update clock time + which = next(rt) + + if isnothing(which) + @info "Stochastic scheduling algorithm ran out of events" + return output + end + + # Unpack data associated with the current event + event::Int, key::Maybe{KeyType} = which + rule::ABMRule, clocks::AbsHomSet = abm.rules[event], rt.clocks[event] + rule_type::Symbol = ruletype(rule) # DPO, SPO, etc. - event, key = which - rule, clock = abm.rules[event], rt.clocks[event] @debug "$(length(output)): event $event fired @ $(rt.tnow)" - m = get_match(pattern_type(rule), pattern(rule), rt.state, clock, key) - update_maps = rewrite_match_maps(getrule(abm.rules[event]), m) - rh, kh, kg = update_maps[:rh], update_maps[:kh], update_maps[:kg] - rt.state = codom(rh) # update runtime state - log!(event, Span(kg, kh)) # record state after event + # If RegularPattern, we have an explicit match, otherwise randomly pick one + m = get_match(pattern_type(rule), pattern(rule), rt.state, clocks, key) - if pattern_type(rule) == EmptyP() # "always enabled" need special treatment - disable!′(which) # their hom-set won't change, so clocks - enable!′(create(rt.state), event) # won't be reset, so do it manually - end + # Excute rewrite rule and unpack results + rw_result = (rule_type, rewrite_match_maps(getrule(abm.rules[event]), m)) + rmap::ACSetTransformation = get_rmap(rw_result...) + (lft, rght) = pmap = get_pmap(rw_result...) - # update matches for all events - for (t, (ruleₜ, clockₜ)) in enumerate(zip(abm.rules, rt.clocks)) - pt = pattern_type(ruleₜ) - if pt == RegularP() # update explicit hom-set - homs = clockₜ.val - for d in Incremental.deletion!(clockₜ, kg) - disable!′(t => d) # disable clocks which are invalidated + rt.state = codom(rmap) # update runtime state + log!(event, pmap) # record event result + + # update matches for all events + #------------------------------ + # The only time EmptyPattern rules update is when they are fired + if pattern_type(rule) == EmptyP() + disable!′(which) + enable!′(create(rt.state), event) + end + # All other rules can potentially update in response to the current event + for (i, (ruleᵢ, clocksᵢ)) in enumerate(zip(abm.rules, rt.clocks)) + pt = pattern_type(ruleᵢ) + if pt == RegularP() # update explicit hom-set w/r/t span Xₙ ↩ • -> Xₙ₊₁ + for d in deletion!(clocksᵢ, lft) + disable!′(i => d) # disable clocks which are invalidated end - for a in Incremental.addition!(clockₜ, event, rh, kh) - enable!′(clockₜ[a], t, a) + for a in addition!(clocksᵢ, event, rmap, rght) # rght: R → Xₙ₊₁ + enable!′(clocksᵢ[a], i, a) end elseif pt isa RepresentableP - if !all(o->all(is_isomorphic, [kg[o], kh[o]]), keys(pt)) - haskey(rt, t => nothing) && disable!′(t => nothing) - if all(>(0), nparts.(Ref(rt.state), collect(keys(pt)))) - enable!′(create(rt.state), t) + relevant_obs = keys(pt) + # we need to update current timer if # of parts has changed + if !all(ob -> allequal(nparts.(codom.(pmap), Ref(ob))), relevant_obs) + currently_enabled = haskey(rt, i) + currently_enabled && disable!′(i) # Disable if active + # enable new timer if possible to apply rule + if all(>(0), nparts.(Ref(rt.state), relevant_obs)) + enable!′(create(rt.state), i) end end end diff --git a/src/AlgebraicABMs.jl b/src/AlgebraicABMs.jl index 176d263..ebaac56 100644 --- a/src/AlgebraicABMs.jl +++ b/src/AlgebraicABMs.jl @@ -5,10 +5,10 @@ module AlgebraicABMs using Reexport +include("ABMs.jl") include("Distributions.jl") include("RewriteSemiMarkov.jl") include("PetriInterface.jl") -include("ABMs.jl") @reexport using .Distributions @reexport using .PetriInterface diff --git a/src/RewriteSemiMarkov.jl b/src/RewriteSemiMarkov.jl index 0373bf9..9f397cd 100644 --- a/src/RewriteSemiMarkov.jl +++ b/src/RewriteSemiMarkov.jl @@ -1,10 +1,9 @@ module RewriteSemiMarkov -export run! - using Catlab, AlgebraicRewriting, AlgebraicPetri using Random using Fleck +import ..ABMs: run! # -------------------------------------------------------------------------------- # we want something to store the rules, clocks associated to each, and their type diff --git a/test/ABMS.jl b/test/ABMS.jl index d31947d..6a6c77a 100644 --- a/test/ABMS.jl +++ b/test/ABMS.jl @@ -1,35 +1,48 @@ module TestABMs -# ENV["JULIA_DEBUG"] = "AlgebraicABMs" +ENV["JULIA_DEBUG"] = "AlgebraicABMs" # turn on @debug messages for this package using Test using AlgebraicABMs - using Catlab, AlgebraicRewriting -using AlgebraicABMs.ABMs: ABMRule, DiscreteHazard, ContinuousHazard, RegularP, - EmptyP, RepresentableP, RuntimeABM +using AlgebraicABMs.ABMs: RegularP, EmptyP, RepresentableP, RuntimeABM + +# L = ∅, I = ∅, R = •↺ +create_loop = ABMRule(Rule(id(Graph()), # l : I -> L + create(ob(terminal(Graph)))), # r : I → R + DiscreteHazard(1.)) # Dirac delta, indep. of clock time / state -create_vertex = ABMRule(Rule(id(Graph()), create(ob(terminal(Graph)))), DiscreteHazard(1.)) -@test create_vertex.pattern_type == EmptyP() +# check that we know this rule has an empty pattern L +@test create_loop.pattern_type == EmptyP() -add_loop = ABMRule(Rule(id(Graph(1)), delete(Graph(1))), DiscreteHazard(1.5)) -@test add_loop.pattern_type isa RepresentableP +# • ← • → •↺ +add_loop = ABMRule(Rule(id(Graph(1)), # + delete(Graph(1))), # r : I -> R + DiscreteHazard(1.5)) +@test add_loop.pattern_type == RepresentableP(Dict(:V=>[1])) +# •↺ ⇽ • → • rem_loop = ABMRule(Rule(delete(Graph(1)), id(Graph(1))), DiscreteHazard(2)) @test rem_loop.pattern_type == RegularP() -rem_edge = ABMRule(Rule(homomorphism(Graph(2), path_graph(Graph, 2); - monic=true), id(Graph(2))), ContinuousHazard(1)) -@test rem_edge.pattern_type isa RepresentableP - +# •→• ⇽ • → • +rem_edge = ABMRule(Rule(homomorphism(Graph(2), path_graph(Graph, 2); monic=true), + id(Graph(2))), + ContinuousHazard(1)) +@test rem_edge.pattern_type == RepresentableP(Dict(:E=>[1])) +# Create initial state G = @acset Graph begin V=3; E=3; src=[1,1,1]; tgt=[1,1,2] end -abm = ABM([create_vertex, add_loop, rem_loop, rem_edge]) +to_graphviz(G) -@test only(RuntimeABM(abm, G).clocks[3].val.match_vect)[1][:E](1) == 1 # One cached hom +# Assemble rules into ABM +abm = ABM([create_loop, add_loop, rem_loop, rem_edge]) +# 2 loops, so 2 cached homs for the only rule with an explicit hom set +@test length(only(RuntimeABM(abm, G).clocks[3].val.match_vect)) == 2 -traj = run!(abm, G); +traj = run!(abm, G; maxevent=10); +@test length(traj) == 10 end # module diff --git a/test/Project.toml b/test/Project.toml index aa9f405..9386d92 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,10 +3,8 @@ AlgebraicABMs = "5a5e3447-9604-46e6-8d91-cb86f5f51721" AlgebraicPetri = "4f99eebe-17bf-4e98-b6a1-2c4f205a959b" AlgebraicRewriting = "725a01d3-f174-5bbd-84e1-b9417bad95d9" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Fleck = "5bb9b785-358c-4fee-af0f-b94a146244a8" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"