Skip to content

Commit 588b42b

Browse files
committed
Use Nothing default for StructOfConstraints
1 parent bed957f commit 588b42b

File tree

1 file changed

+76
-82
lines changed

1 file changed

+76
-82
lines changed

src/Utilities/struct_of_constraints.jl

Lines changed: 76 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,36 @@ abstract type StructOfConstraints <: MOI.ModelLike end
22

33
function _throw_if_cannot_delete(model::StructOfConstraints, vis, fast_in_vis)
44
broadcastcall(model) do constrs
5-
return _throw_if_cannot_delete(constrs, vis, fast_in_vis)
5+
if constrs !== nothing
6+
_throw_if_cannot_delete(constrs, vis, fast_in_vis)
7+
end
8+
return
69
end
10+
return
711
end
812
function _deleted_constraints(
913
callback::Function,
1014
model::StructOfConstraints,
1115
vi,
1216
)
1317
broadcastcall(model) do constrs
14-
return _deleted_constraints(callback, constrs, vi)
18+
if constrs !== nothing
19+
_deleted_constraints(callback, constrs, vi)
20+
end
21+
return
1522
end
23+
return
1624
end
1725

1826
function MOI.add_constraint(
1927
model::StructOfConstraints,
20-
func::MOI.AbstractFunction,
21-
set::MOI.AbstractSet,
22-
)
23-
if MOI.supports_constraint(model, typeof(func), typeof(set))
24-
return MOI.add_constraint(
25-
constraints(model, typeof(func), typeof(set)),
26-
func,
27-
set,
28-
)
29-
else
30-
throw(MOI.UnsupportedConstraint{typeof(func),typeof(set)}())
28+
func::F,
29+
set::S,
30+
) where {F<:MOI.AbstractFunction,S<:MOI.AbstractSet}
31+
if !MOI.supports_constraint(model, F, S)
32+
throw(MOI.UnsupportedConstraint{F,S}())
3133
end
34+
return MOI.add_constraint(constraints(model, F, S), func, set)
3235
end
3336

3437
function constraints(
@@ -49,26 +52,27 @@ function MOI.get(
4952
end
5053

5154
function MOI.delete(model::StructOfConstraints, ci::MOI.ConstraintIndex)
52-
return MOI.delete(constraints(model, ci), ci)
55+
MOI.delete(constraints(model, ci), ci)
56+
return
5357
end
5458

5559
function MOI.is_valid(
5660
model::StructOfConstraints,
5761
ci::MOI.ConstraintIndex{F,S},
5862
) where {F,S}
59-
if MOI.supports_constraint(model, F, S)
60-
return MOI.is_valid(constraints(model, ci), ci)
61-
else
63+
if !MOI.supports_constraint(model, F, S)
6264
return false
6365
end
66+
return MOI.is_valid(constraints(model, ci), ci)
6467
end
6568

6669
function MOI.modify(
6770
model::StructOfConstraints,
6871
ci::MOI.ConstraintIndex,
6972
change::MOI.AbstractFunctionModification,
7073
)
71-
return MOI.modify(constraints(model, ci), ci, change)
74+
MOI.modify(constraints(model, ci), ci, change)
75+
return
7276
end
7377

7478
function MOI.set(
@@ -77,45 +81,55 @@ function MOI.set(
7781
ci::MOI.ConstraintIndex,
7882
func_or_set,
7983
)
80-
return MOI.set(constraints(model, ci), attr, ci, func_or_set)
84+
MOI.set(constraints(model, ci), attr, ci, func_or_set)
85+
return
8186
end
8287

8388
function MOI.get(
8489
model::StructOfConstraints,
85-
loc::MOI.ListOfConstraintTypesPresent,
86-
) where {T}
87-
return broadcastvcat(model) do v
88-
return MOI.get(v, loc)
90+
attr::MOI.ListOfConstraintTypesPresent,
91+
)
92+
return broadcastvcat(model) do constrs
93+
if constrs === nothing
94+
return Tuple{DataType,DataType}[]
95+
end
96+
return MOI.get(constrs, attr)
8997
end
9098
end
9199

92100
function MOI.get(
93101
model::StructOfConstraints,
94-
noc::MOI.NumberOfConstraints{F,S},
102+
attr::MOI.NumberOfConstraints{F,S},
95103
) where {F,S}
96-
if MOI.supports_constraint(model, F, S)
97-
return MOI.get(constraints(model, F, S), noc)
98-
else
104+
if !MOI.supports_constraint(model, F, S)
99105
return 0
100106
end
107+
return MOI.get(constraints(model, F, S), attr)
101108
end
102109

103110
function MOI.get(
104111
model::StructOfConstraints,
105-
loc::MOI.ListOfConstraintIndices{F,S},
112+
attr::MOI.ListOfConstraintIndices{F,S},
106113
) where {F,S}
107-
if MOI.supports_constraint(model, F, S)
108-
return MOI.get(constraints(model, F, S), loc)
109-
else
114+
if !MOI.supports_constraint(model, F, S)
110115
return MOI.ConstraintIndex{F,S}[]
111116
end
117+
return MOI.get(constraints(model, F, S), attr)
112118
end
113119

114120
function MOI.is_empty(model::StructOfConstraints)
115-
return mapreduce_constraints(MOI.is_empty, &, model, true)
121+
return mapreduce_constraints(&, model, true) do constrs
122+
return constrs === nothing || MOI.is_empty(constrs)
123+
end
116124
end
117125
function MOI.empty!(model::StructOfConstraints)
118-
return broadcastcall(MOI.empty!, model)
126+
broadcastcall(model) do constrs
127+
if constrs !== nothing
128+
MOI.empty!(constrs)
129+
end
130+
return
131+
end
132+
return
119133
end
120134

121135
# Can be used to access constraints of a model
@@ -178,33 +192,13 @@ end
178192

179193
# (MOI, :Zeros) -> :(MOI.Zeros)
180194
# (:Zeros) -> :(MOI.Zeros)
181-
_set(s::SymbolSet) = esc(s.s)
182-
_fun(s::SymbolFun) = esc(s.s)
183-
function _typedset(s::SymbolSet)
184-
if s.typed
185-
T = esc(:T)
186-
:($(_set(s)){$T})
187-
else
188-
_set(s)
189-
end
190-
end
191-
function _typedfun(s::SymbolFun)
192-
if s.typed
193-
T = esc(:T)
194-
:($(_fun(s)){$T})
195-
else
196-
_fun(s)
197-
end
198-
end
195+
_typedset(s::SymbolSet) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s)
196+
_typedfun(s::SymbolFun) = s.typed ? Expr(:curly, esc(s.s), esc(:T)) : esc(s.s)
199197

200198
# Base.lowercase is moved to Unicode.lowercase in Julia v0.7
201199
using Unicode
202200

203201
_field(s::SymbolFS) = Symbol(replace(lowercase(string(s.s)), "." => "_"))
204-
205-
_getC(s::SymbolSet) = :(VectorOfConstraints{F,$(_typedset(s))})
206-
_getC(s::SymbolFun) = _typedfun(s)
207-
208202
_callfield(f, s::SymbolFS) = :($f(model.$(_field(s))))
209203
_broadcastfield(b, s::SymbolFS) = :($b(f, model.$(_field(s))))
210204
_mapreduce_field(s::SymbolFS) = :(cur = op(cur, f(model.$(_field(s)))))
@@ -223,62 +217,61 @@ If `types` is vector of `SymbolFun` (resp. `SymbolSet`) then the constraints
223217
of that function (resp. set) type are stored in the corresponding field.
224218
"""
225219
function struct_of_constraint_code(struct_name, types, field_types = nothing)
226-
esc_struct_name = struct_name
227220
T = esc(:T)
228-
typed_struct = :($(esc_struct_name){$T})
221+
typed_struct = :($(struct_name){$T})
229222
type_parametrized = field_types === nothing
230223
if type_parametrized
231224
field_types = [Symbol("C$i") for i in eachindex(types)]
232225
append!(typed_struct.args, field_types)
233226
end
227+
struct_def = :(mutable struct $typed_struct <: StructOfConstraints end)
234228

235-
struct_def = :(struct $typed_struct <: StructOfConstraints end)
236-
237-
for (t, field_type) in zip(types, field_types)
238-
field = _field(t)
239-
push!(struct_def.args[3].args, :($field::$field_type))
240-
end
241229
code = quote
242-
function $MOIU.broadcastcall(f::Function, model::$esc_struct_name)
243-
return $(Expr(:block, _callfield.(Ref(:f), types)...))
230+
function $MOIU.broadcastcall(f::Function, model::$struct_name)
231+
$(Expr(:block, _callfield.(Ref(:f), types)...))
232+
return
244233
end
245-
function $MOIU.broadcastvcat(f::Function, model::$esc_struct_name)
234+
235+
function $MOIU.broadcastvcat(f::Function, model::$struct_name)
246236
return vcat($(_callfield.(Ref(:f), types)...))
247237
end
238+
248239
function $MOIU.mapreduce_constraints(
249240
f::Function,
250241
op::Function,
251-
model::$esc_struct_name,
242+
model::$struct_name,
252243
cur,
253244
)
254245
return $(Expr(:block, _mapreduce_field.(types)...))
255246
end
256247
end
257248

258-
for t in types
259-
if t isa SymbolFun
260-
fun = _fun(t)
261-
set = :(MOI.AbstractSet)
262-
else
263-
fun = :(MOI.AbstractFunction)
264-
set = _set(t)
265-
end
249+
for (t, field_type) in zip(types, field_types)
266250
field = _field(t)
267-
code = quote
268-
$code
251+
push!(struct_def.args[3].args, :($field::Union{Nothing,$field_type}))
252+
fun = t isa SymbolFun ? esc(t.s) : :(MOI.AbstractFunction)
253+
set = t isa SymbolFun ? :(MOI.AbstractSet) : esc(t.s)
254+
constraints_code = :(
269255
function $MOIU.constraints(
270-
model::$esc_struct_name,
256+
model::$typed_struct,
271257
::Type{<:$fun},
272258
::Type{<:$set},
273-
)
259+
) where {$T}
260+
if model.$field === nothing
261+
model.$field = $(field_type)()
262+
end
274263
return model.$field
275264
end
265+
)
266+
if type_parametrized
267+
append!(constraints_code.args[1].args, field_types)
276268
end
269+
push!(code.args, constraints_code)
277270
end
278271
supports_code = if eltype(types) <: SymbolFun
279272
quote
280273
function $MOI.supports_constraint(
281-
model::$esc_struct_name{$T},
274+
model::$struct_name{$T},
282275
::Type{F},
283276
::Type{S},
284277
) where {
@@ -293,7 +286,7 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)
293286
@assert eltype(types) <: SymbolSet
294287
quote
295288
function $MOI.supports_constraint(
296-
model::$esc_struct_name{$T},
289+
model::$struct_name{$T},
297290
::Type{F},
298291
::Type{S},
299292
) where {
@@ -307,7 +300,8 @@ function struct_of_constraint_code(struct_name, types, field_types = nothing)
307300
end
308301
expr = Expr(:block, struct_def, supports_code, code)
309302
if !isempty(field_types)
310-
constructors = [:($field_type()) for field_type in field_types]
303+
constructors = [:(nothing) for field_type in field_types]
304+
# constructors = [:($field_type()) for field_type in field_types]
311305
# If there is no field type, the default constructor is sufficient and
312306
# adding this constructor will make a `StackOverflow`.
313307
constructor_code = :(function $typed_struct() where {$T}

0 commit comments

Comments
 (0)