Skip to content

Commit

Permalink
Use Nothing default for StructOfConstraints
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed May 6, 2021
1 parent bed957f commit 588b42b
Showing 1 changed file with 76 additions and 82 deletions.
158 changes: 76 additions & 82 deletions src/Utilities/struct_of_constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,36 @@ abstract type StructOfConstraints <: MOI.ModelLike end

function _throw_if_cannot_delete(model::StructOfConstraints, vis, fast_in_vis)
broadcastcall(model) do constrs
return _throw_if_cannot_delete(constrs, vis, fast_in_vis)
if constrs !== nothing
_throw_if_cannot_delete(constrs, vis, fast_in_vis)
end
return
end
return
end
function _deleted_constraints(
callback::Function,
model::StructOfConstraints,
vi,
)
broadcastcall(model) do constrs
return _deleted_constraints(callback, constrs, vi)
if constrs !== nothing
_deleted_constraints(callback, constrs, vi)
end
return
end
return
end

function MOI.add_constraint(
model::StructOfConstraints,
func::MOI.AbstractFunction,
set::MOI.AbstractSet,
)
if MOI.supports_constraint(model, typeof(func), typeof(set))
return MOI.add_constraint(
constraints(model, typeof(func), typeof(set)),
func,
set,
)
else
throw(MOI.UnsupportedConstraint{typeof(func),typeof(set)}())
func::F,
set::S,
) where {F<:MOI.AbstractFunction,S<:MOI.AbstractSet}
if !MOI.supports_constraint(model, F, S)
throw(MOI.UnsupportedConstraint{F,S}())
end
return MOI.add_constraint(constraints(model, F, S), func, set)
end

function constraints(
Expand All @@ -49,26 +52,27 @@ function MOI.get(
end

function MOI.delete(model::StructOfConstraints, ci::MOI.ConstraintIndex)
return MOI.delete(constraints(model, ci), ci)
MOI.delete(constraints(model, ci), ci)
return
end

function MOI.is_valid(
model::StructOfConstraints,
ci::MOI.ConstraintIndex{F,S},
) where {F,S}
if MOI.supports_constraint(model, F, S)
return MOI.is_valid(constraints(model, ci), ci)
else
if !MOI.supports_constraint(model, F, S)
return false
end
return MOI.is_valid(constraints(model, ci), ci)
end

function MOI.modify(
model::StructOfConstraints,
ci::MOI.ConstraintIndex,
change::MOI.AbstractFunctionModification,
)
return MOI.modify(constraints(model, ci), ci, change)
MOI.modify(constraints(model, ci), ci, change)
return
end

function MOI.set(
Expand All @@ -77,45 +81,55 @@ function MOI.set(
ci::MOI.ConstraintIndex,
func_or_set,
)
return MOI.set(constraints(model, ci), attr, ci, func_or_set)
MOI.set(constraints(model, ci), attr, ci, func_or_set)
return
end

function MOI.get(
model::StructOfConstraints,
loc::MOI.ListOfConstraintTypesPresent,
) where {T}
return broadcastvcat(model) do v
return MOI.get(v, loc)
attr::MOI.ListOfConstraintTypesPresent,
)
return broadcastvcat(model) do constrs
if constrs === nothing
return Tuple{DataType,DataType}[]
end
return MOI.get(constrs, attr)
end
end

function MOI.get(
model::StructOfConstraints,
noc::MOI.NumberOfConstraints{F,S},
attr::MOI.NumberOfConstraints{F,S},
) where {F,S}
if MOI.supports_constraint(model, F, S)
return MOI.get(constraints(model, F, S), noc)
else
if !MOI.supports_constraint(model, F, S)
return 0
end
return MOI.get(constraints(model, F, S), attr)
end

function MOI.get(
model::StructOfConstraints,
loc::MOI.ListOfConstraintIndices{F,S},
attr::MOI.ListOfConstraintIndices{F,S},
) where {F,S}
if MOI.supports_constraint(model, F, S)
return MOI.get(constraints(model, F, S), loc)
else
if !MOI.supports_constraint(model, F, S)
return MOI.ConstraintIndex{F,S}[]
end
return MOI.get(constraints(model, F, S), attr)
end

function MOI.is_empty(model::StructOfConstraints)
return mapreduce_constraints(MOI.is_empty, &, model, true)
return mapreduce_constraints(&, model, true) do constrs
return constrs === nothing || MOI.is_empty(constrs)
end
end
function MOI.empty!(model::StructOfConstraints)
return broadcastcall(MOI.empty!, model)
broadcastcall(model) do constrs
if constrs !== nothing
MOI.empty!(constrs)
end
return
end
return
end

# Can be used to access constraints of a model
Expand Down Expand Up @@ -178,33 +192,13 @@ end

# (MOI, :Zeros) -> :(MOI.Zeros)
# (:Zeros) -> :(MOI.Zeros)
_set(s::SymbolSet) = esc(s.s)
_fun(s::SymbolFun) = esc(s.s)
function _typedset(s::SymbolSet)
if s.typed
T = esc(:T)
:($(_set(s)){$T})
else
_set(s)
end
end
function _typedfun(s::SymbolFun)
if s.typed
T = esc(:T)
:($(_fun(s)){$T})
else
_fun(s)
end
end
_typedset(s::SymbolSet) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s)
_typedfun(s::SymbolFun) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s)

# Base.lowercase is moved to Unicode.lowercase in Julia v0.7
using Unicode

_field(s::SymbolFS) = Symbol(replace(lowercase(string(s.s)), "." => "_"))

_getC(s::SymbolSet) = :(VectorOfConstraints{F,$(_typedset(s))})
_getC(s::SymbolFun) = _typedfun(s)

_callfield(f, s::SymbolFS) = :($f(model.$(_field(s))))
_broadcastfield(b, s::SymbolFS) = :($b(f, model.$(_field(s))))
_mapreduce_field(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s)))))
Expand All @@ -223,62 +217,61 @@ If `types` is vector of `SymbolFun` (resp. `SymbolSet`) then the constraints
of that function (resp. set) type are stored in the corresponding field.
"""
function struct_of_constraint_code(struct_name, types, field_types = nothing)
esc_struct_name = struct_name
T = esc(:T)
typed_struct = :($(esc_struct_name){$T})
typed_struct = :($(struct_name){$T})
type_parametrized = field_types === nothing
if type_parametrized
field_types = [Symbol("C$i") for i in eachindex(types)]
append!(typed_struct.args, field_types)
end
struct_def = :(mutable struct $typed_struct <: StructOfConstraints end)

struct_def = :(struct $typed_struct <: StructOfConstraints end)

for (t, field_type) in zip(types, field_types)
field = _field(t)
push!(struct_def.args[3].args, :($field::$field_type))
end
code = quote
function $MOIU.broadcastcall(f::Function, model::$esc_struct_name)
return $(Expr(:block, _callfield.(Ref(:f), types)...))
function $MOIU.broadcastcall(f::Function, model::$struct_name)
$(Expr(:block, _callfield.(Ref(:f), types)...))
return
end
function $MOIU.broadcastvcat(f::Function, model::$esc_struct_name)

function $MOIU.broadcastvcat(f::Function, model::$struct_name)
return vcat($(_callfield.(Ref(:f), types)...))
end

function $MOIU.mapreduce_constraints(
f::Function,
op::Function,
model::$esc_struct_name,
model::$struct_name,
cur,
)
return $(Expr(:block, _mapreduce_field.(types)...))
end
end

for t in types
if t isa SymbolFun
fun = _fun(t)
set = :(MOI.AbstractSet)
else
fun = :(MOI.AbstractFunction)
set = _set(t)
end
for (t, field_type) in zip(types, field_types)
field = _field(t)
code = quote
$code
push!(struct_def.args[3].args, :($field::Union{Nothing,$field_type}))
fun = t isa SymbolFun ? esc(t.s) : :(MOI.AbstractFunction)
set = t isa SymbolFun ? :(MOI.AbstractSet) : esc(t.s)
constraints_code = :(
function $MOIU.constraints(
model::$esc_struct_name,
model::$typed_struct,
::Type{<:$fun},
::Type{<:$set},
)
) where {$T}
if model.$field === nothing
model.$field = $(field_type)()
end
return model.$field
end
)
if type_parametrized
append!(constraints_code.args[1].args, field_types)
end
push!(code.args, constraints_code)
end
supports_code = if eltype(types) <: SymbolFun
quote
function $MOI.supports_constraint(
model::$esc_struct_name{$T},
model::$struct_name{$T},
::Type{F},
::Type{S},
) where {
Expand All @@ -293,7 +286,7 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)
@assert eltype(types) <: SymbolSet
quote
function $MOI.supports_constraint(
model::$esc_struct_name{$T},
model::$struct_name{$T},
::Type{F},
::Type{S},
) where {
Expand All @@ -307,7 +300,8 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)
end
expr = Expr(:block, struct_def, supports_code, code)
if !isempty(field_types)
constructors = [:($field_type()) for field_type in field_types]
constructors = [:(nothing) for field_type in field_types]
# constructors = [:($field_type()) for field_type in field_types]
# If there is no field type, the default constructor is sufficient and
# adding this constructor will make a `StackOverflow`.
constructor_code = :(function $typed_struct() where {$T}
Expand Down

0 comments on commit 588b42b

Please sign in to comment.