Skip to content

Commit

Permalink
Merge pull request #1122 from vyudu/dsl-no-infer
Browse files Browse the repository at this point in the history
Add @no_infer flag for turning off species/variable/parameter inferring
  • Loading branch information
isaacsas authored Nov 20, 2024
2 parents ff149fd + c710814 commit ecd28d6
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 17 deletions.
61 changes: 44 additions & 17 deletions src/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ const pure_rate_arrows = Set{Symbol}([:(=>), :(<=), :⇐, :⟽, :⇒, :⟾, :⇔
# Declares the keys used for various options.
const option_keys = (:species, :parameters, :variables, :ivs, :compounds, :observables,
:default_noise_scaling, :differentials, :equations,
:continuous_events, :discrete_events, :combinatoric_ratelaws)
:continuous_events, :discrete_events, :combinatoric_ratelaws, :require_declaration)

### `@species` Macro ###

Expand Down Expand Up @@ -220,13 +220,14 @@ struct ReactionStruct
products::Vector{ReactantStruct}
rate::ExprValues
metadata::Expr
rxexpr::Expr

function ReactionStruct(sub_line::ExprValues, prod_line::ExprValues, rate::ExprValues,
metadata_line::ExprValues)
metadata_line::ExprValues, rx_line::Expr)
sub = recursive_find_reactants!(sub_line, 1, Vector{ReactantStruct}(undef, 0))
prod = recursive_find_reactants!(prod_line, 1, Vector{ReactantStruct}(undef, 0))
metadata = extract_metadata(metadata_line)
new(sub, prod, rate, metadata)
new(sub, prod, rate, metadata, rx_line)
end
end

Expand Down Expand Up @@ -283,6 +284,17 @@ function extract_metadata(metadata_line::Expr)
return metadata
end



struct UndeclaredSymbolicError <: Exception
msg::String
end

function Base.showerror(io::IO, err::UndeclaredSymbolicError)
print(io, "UndeclaredSymbolicError: ")
print(io, err.msg)
end

### DSL Internal Master Function ###

# Function for creating a ReactionSystem structure (used by the @reaction_network macro).
Expand All @@ -308,6 +320,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
compound_expr, compound_species = read_compound_options(options)
continuous_events_expr = read_events_option(options, :continuous_events)
discrete_events_expr = read_events_option(options, :discrete_events)
requiredec = haskey(options, :require_declaration)

# Parses reactions, species, and parameters.
reactions = get_reactions(reaction_lines)
Expand All @@ -317,7 +330,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))

# Reads equations.
vars_extracted, add_default_diff, equations = read_equations_options(
options, variables_declared)
options, variables_declared; requiredec)
variables = vcat(variables_declared, vars_extracted)

# Handle independent variables
Expand All @@ -341,13 +354,13 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))

# Reads observables.
observed_vars, observed_eqs, obs_syms = read_observed_options(
options, [species_declared; variables], all_ivs)
options, [species_declared; variables], all_ivs; requiredec)

# Collect species and parameters, including ones inferred from the reactions.
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
variables)))
species_extracted, parameters_extracted = extract_species_and_parameters!(
reactions, declared_syms)
reactions, declared_syms; requiredec)

species = vcat(species_declared, species_extracted)
parameters = vcat(parameters_declared, parameters_extracted)
Expand Down Expand Up @@ -425,15 +438,15 @@ function get_reactions(exprs::Vector{Expr}, reactions = Vector{ReactionStruct}(u
error("Error: Must provide a tuple of reaction rates when declaring a bi-directional reaction.")
end
push_reactions!(reactions, reaction.args[2], reaction.args[3],
rate.args[1], metadata.args[1], arrow)
rate.args[1], metadata.args[1], arrow, line)
push_reactions!(reactions, reaction.args[3], reaction.args[2],
rate.args[2], metadata.args[2], arrow)
rate.args[2], metadata.args[2], arrow, line)
elseif in(arrow, fwd_arrows)
push_reactions!(reactions, reaction.args[2], reaction.args[3],
rate, metadata, arrow)
rate, metadata, arrow, line)
elseif in(arrow, bwd_arrows)
push_reactions!(reactions, reaction.args[3], reaction.args[2],
rate, metadata, arrow)
rate, metadata, arrow, line)
else
throw("Malformed reaction, invalid arrow type used in: $(MacroTools.striplines(line))")
end
Expand Down Expand Up @@ -467,7 +480,7 @@ end
# Takes a reaction line and creates reaction(s) from it and pushes those to the reaction array.
# Used to create multiple reactions from, for instance, `k, (X,Y) --> 0`.
function push_reactions!(reactions::Vector{ReactionStruct}, sub_line::ExprValues,
prod_line::ExprValues, rate::ExprValues, metadata::ExprValues, arrow::Symbol)
prod_line::ExprValues, rate::ExprValues, metadata::ExprValues, arrow::Symbol, line::Expr)
# The rates, substrates, products, and metadata may be in a tupple form (e.g. `k, (X,Y) --> 0`).
# This finds the length of these tuples (or 1 if not in tuple forms). Errors if lengs inconsistent.
lengs = (tup_leng(sub_line), tup_leng(prod_line), tup_leng(rate), tup_leng(metadata))
Expand All @@ -490,7 +503,7 @@ function push_reactions!(reactions::Vector{ReactionStruct}, sub_line::ExprValues

push!(reactions,
ReactionStruct(get_tup_arg(sub_line, i),
get_tup_arg(prod_line, i), get_tup_arg(rate, i), metadata_i))
get_tup_arg(prod_line, i), get_tup_arg(rate, i), metadata_i, line))
end
end

Expand All @@ -511,20 +524,26 @@ end

# Function looping through all reactions, to find undeclared symbols (species or
# parameters), and assign them to the right category.
function extract_species_and_parameters!(reactions, excluded_syms)
function extract_species_and_parameters!(reactions, excluded_syms; requiredec = false)
species = OrderedSet{Union{Symbol, Expr}}()
for reaction in reactions
for reactant in Iterators.flatten((reaction.substrates, reaction.products))
add_syms_from_expr!(species, reactant.reactant, excluded_syms)
(!isempty(species) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized variables $(join(species, ", ")) detected in reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all species must be explicitly declared with the @species macro."))
end
end

foreach(s -> push!(excluded_syms, s), species)
parameters = OrderedSet{Union{Symbol, Expr}}()
for reaction in reactions
add_syms_from_expr!(parameters, reaction.rate, excluded_syms)
(!isempty(parameters) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized parameter $(join(parameters, ", ")) detected in rate expression: $(reaction.rate) for the following reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all parameters must be explicitly declared with the @parameters macro."))
for reactant in Iterators.flatten((reaction.substrates, reaction.products))
add_syms_from_expr!(parameters, reactant.stoichiometry, excluded_syms)
(!isempty(parameters) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized parameters $(join(parameters, ", ")) detected in the stoichiometry for reactant $(reactant.reactant) in the following reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all parameters must be explicitly declared with the @parameters macro."))
end
end

Expand Down Expand Up @@ -682,7 +701,7 @@ end
# `vars_extracted`: A vector with extracted variables (lhs in pure differential equations only).
# `dtexpr`: If a differential equation is defined, the default derivative (D ~ Differential(t)) must be defined.
# `equations`: a vector with the equations provided.
function read_equations_options(options, variables_declared)
function read_equations_options(options, variables_declared; requiredec = false)
# Prepares the equations. First, extracts equations from provided option (converting to block form if required).
# Next, uses MTK's `parse_equations!` function to split input into a vector with the equations.
eqs_input = haskey(options, :equations) ? options[:equations].args[3] : :(begin end)
Expand Down Expand Up @@ -711,9 +730,13 @@ function read_equations_options(options, variables_declared)
diff_var = lhs.args[2]
if in(diff_var, forbidden_symbols_error)
error("A forbidden symbol ($(diff_var)) was used as an variable in this differential equation: $eq")
elseif (!in(diff_var, variables_declared)) && requiredec
throw(UndeclaredSymbolicError(
"Unrecognized symbol $(diff_var) was used as a variable in an equation: \"$eq\". Since the @require_declaration flag is set, all variables in equations must be explicitly declared via @variables, @species, or @parameters."))
else
add_default_diff = true
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
end
add_default_diff = true
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
end
end

Expand Down Expand Up @@ -752,7 +775,7 @@ function create_differential_expr(options, add_default_diff, used_syms, tiv)
end

# Reads the observables options. Outputs an expression ofr creating the observable variables, and a vector of observable equations.
function read_observed_options(options, species_n_vars_declared, ivs_sorted)
function read_observed_options(options, species_n_vars_declared, ivs_sorted; requiredec = false)
if haskey(options, :observables)
# Gets list of observable equations and prepares variable declaration expression.
# (`options[:observables]` includes `@observables`, `.args[3]` removes this part)
Expand All @@ -763,6 +786,10 @@ function read_observed_options(options, species_n_vars_declared, ivs_sorted)
for (idx, obs_eq) in enumerate(observed_eqs.args)
# Extract the observable, checks errors, and continues the loop if the observable has been declared.
obs_name, ivs, defaults, metadata = find_varinfo_in_declaration(obs_eq.args[2])
if (requiredec && !in(obs_name, species_n_vars_declared))
throw(UndeclaredSymbolicError(
"An undeclared variable ($obs_name) was declared as an observable in the following observable equation: \"$obs_eq\". Since the flag @require_declaration is set, all variables must be declared with the @species, @parameters, or @variables macros."))
end
isempty(ivs) ||
error("An observable ($obs_name) was given independent variable(s). These should not be given, as they are inferred automatically.")
isnothing(defaults) ||
Expand Down
70 changes: 70 additions & 0 deletions test/dsl/dsl_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1022,3 +1022,73 @@ let
@parameters v n
@test isequal(Catalyst.expand_registered_functions(equations(rn4)[1]), D(A) ~ v*(A^n))
end

### test that @no_infer properly throws errors when undeclared variables are written

import Catalyst: UndeclaredSymbolicError
let
# Test error when species are inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@parameters k
k, A --> B
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@species A(t) B(t)
@parameters k
k, A --> B
end

# Test error when a parameter in rate is inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@species A(t) B(t)
@parameters k
k*n, A --> B
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@parameters n k
@species A(t) B(t)
k*n, A --> B
end

# Test error when a parameter in stoichiometry is inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@parameters k
@species A(t) B(t)
k, n*A --> B
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@parameters k n
@species A(t) B(t)
k, n*A --> B
end

# Test error when a variable in an equation is inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@equations D(V) ~ V^2
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@variables V(t)
@equations D(V) ~ V^2
end

# Test error when a variable in an observable is inferred
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
@require_declaration
@variables X1(t)
@observables X2 ~ X1
end
@test_nowarn @macroexpand @reaction_network begin
@require_declaration
@variables X1(t) X2(t)
@observables X2 ~ X1
end
end

0 comments on commit ecd28d6

Please sign in to comment.