Skip to content

Commit

Permalink
Simplify the macro code
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed May 6, 2021
1 parent fef8c86 commit 3aec60c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/Utilities/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ macro model(
sets = vector_sets
end
voc = map(sets) do set
return :(VectorOfConstraints{$(_typedfun(funs[i])),$(_typedset(set))})
return :(VectorOfConstraints{$(_typed(funs[i])),$(_typed(set))})
end
return _struct_of_constraints_type(cname, voc, true)
end
Expand Down
102 changes: 38 additions & 64 deletions src/Utilities/struct_of_constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ function _throw_if_cannot_delete(model::StructOfConstraints, vis, fast_in_vis)
end
return
end

function _deleted_constraints(
callback::Function,
model::StructOfConstraints,
Expand Down Expand Up @@ -43,6 +44,7 @@ function constraints(
end
return constraints(model, F, S)
end

function MOI.get(
model::StructOfConstraints,
attr::Union{MOI.ConstraintFunction,MOI.ConstraintSet},
Expand Down Expand Up @@ -122,6 +124,7 @@ function MOI.is_empty(model::StructOfConstraints)
return constrs === nothing || MOI.is_empty(constrs)
end
end

function MOI.empty!(model::StructOfConstraints)
broadcastcall(model) do constrs
if constrs !== nothing
Expand All @@ -134,58 +137,48 @@ end

# Can be used to access constraints of a model
"""
broadcastcall(f::Function, model::AbstractModel)
Calls `f(contrs)` for every vector `constrs::Vector{ConstraintIndex{F, S}, F, S}` of the model.
broadcastcall(f::Function, model::StructOfConstraints)
# Examples
To add all constraints of the model to a solver `solver`, one can do
```julia
_addcon(solver, ci, f, s) = MOI.add_constraint(solver, f, s)
function _addcon(solver, constrs::Vector)
for constr in constrs
_addcon(solver, constr...)
end
end
MOIU.broadcastcall(constrs -> _addcon(solver, constrs), model)
```
Calls `f(contrs)` for every field in `model`.
"""
function broadcastcall end

"""
broadcastvcat(f::Function, model::AbstractModel)
broadcastvcat(f::Function, model::StructOfConstraints)
Calls `f(contrs)` for every vector `constrs::Vector{ConstraintIndex{F, S}, F, S}` of the model and concatenate the results with `vcat` (this is used internally for `ListOfConstraintTypesPresent`).
# Examples
To get the list of all functions:
```julia
_getfun(ci, f, s) = f
_getfun(cindices::Tuple) = _getfun(cindices...)
_getfuns(constrs::Vector) = _getfun.(constrs)
MOIU.broadcastvcat(_getfuns, model)
```
Calls `f(contrs)` for every field in `model` and `vcat`s the results.
"""
function broadcastvcat end

"""
mapreduce_constraints(
f::Function,
op::Function,
model::StructOfConstraints,
init,
)
Call `mapreduce` on every field of `model` given an initial value `init`. Each
element in the map is computed as `f(x)` and the elements are reduced using
`op`.
"""
function mapreduce_constraints end

# Macro code

abstract type SymbolFS end

struct SymbolFun <: SymbolFS
s::Union{Symbol,Expr}
typed::Bool
end

struct SymbolSet <: SymbolFS
s::Union{Symbol,Expr}
typed::Bool
end

_typedset(s::SymbolSet) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s)
_typedfun(s::SymbolFun) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s)
_typed(s::SymbolFS) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s)

# Base.lowercase is moved to Unicode.lowercase in Julia v0.7
import Unicode
Expand All @@ -195,9 +188,8 @@ function _field(s::SymbolFS)
end

_callfield(f, s::SymbolFS) = :($f(model.$(_field(s))))
_broadcastfield(b, s::SymbolFS) = :($b(f, model.$(_field(s))))

_mapreduce_field(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s)))))
_mapreduce_constraints(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s)))))

"""
struct_of_constraint_code(struct_name, types, field_types = nothing)
Expand All @@ -219,9 +211,9 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)
field_types = [Symbol("C$i") for i in eachindex(types)]
append!(typed_struct.args, field_types)
end
struct_def = :(mutable struct $typed_struct <: StructOfConstraints end)

code = quote
mutable struct $typed_struct <: StructOfConstraints end

function $MOIU.broadcastcall(f::Function, model::$struct_name)
$(Expr(:block, _callfield.(Ref(:f), types)...))
return
Expand All @@ -243,7 +235,7 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)

for (t, field_type) in zip(types, field_types)
field = _field(t)
push!(struct_def.args[3].args, :($field::Union{Nothing,$field_type}))
push!(code.args[2].args[3].args, :($field::Union{Nothing,$field_type}))
fun = t isa SymbolFun ? esc(t.s) : :(MOI.AbstractFunction)
set = t isa SymbolFun ? :(MOI.AbstractSet) : esc(t.s)
constraints_code = :(
Expand All @@ -263,49 +255,31 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)
end
push!(code.args, constraints_code)
end
supports_code = if eltype(types) <: SymbolFun
quote
function $MOI.supports_constraint(
model::$struct_name{$T},
::Type{F},
::Type{S},
) where {
$T,
F<:Union{$(_typedfun.(types)...)},
S<:MOI.AbstractSet,
}
return $MOI.supports_constraint(constraints(model, F, S), F, S)
end
end
else
@assert eltype(types) <: SymbolSet
quote
is_func = eltype(types) <: SymbolFun
SuperF = is_func ? :(Union{$(_typed.(types)...)}) : :(MOI.AbstractFunction)
SuperS = is_func ? :(MOI.AbstractSet) : :(Union{$(_typed.(types)...)})
push!(
code.args,
:(
function $MOI.supports_constraint(
model::$struct_name{$T},
::Type{F},
::Type{S},
) where {
$T,
F<:MOI.AbstractFunction,
S<:Union{$(_typedset.(types)...)},
}
) where {$T,F<:$SuperF,S<:$SuperS}
return $MOI.supports_constraint(constraints(model, F, S), F, S)
end
end
end
expr = Expr(:block, struct_def, supports_code, code)
),
)
if !isempty(field_types)
constructors = [:(nothing) for field_type in field_types]
# constructors = [:($field_type()) for field_type in field_types]
# If there is no field type, the default constructor is sufficient and
# adding this constructor will make a `StackOverflow`.
constructor_code = :(function $typed_struct() where {$T}
return $typed_struct($(constructors...))
return $typed_struct($([:(nothing) for _ in field_types]...))
end)
if type_parametrized
append!(constructor_code.args[1].args, field_types)
end
push!(expr.args, constructor_code)
push!(code.args, constructor_code)
end
return expr
return code
end

0 comments on commit 3aec60c

Please sign in to comment.