Skip to content

Commit

Permalink
improve performance of internal semi-infintie variables (infiniteopt#359
Browse files Browse the repository at this point in the history
)
  • Loading branch information
pulsipher authored Aug 2, 2024
1 parent c64333f commit 59a8650
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 86 deletions.
19 changes: 18 additions & 1 deletion src/TranscriptionOpt/measures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,24 @@ function InfiniteOpt.add_semi_infinite_variable(
# make the reference and map it to a transcription variable
rvref = InfiniteOpt.GeneralVariableRef(inf_model, raw_index, InfiniteOpt.SemiInfiniteVariableIndex)
push!(semi_infinite_vars, var)
_set_semi_infinite_variable_mapping(backend, var, rvref, InfiniteOpt._index_type(ivref))
if ivref.index_type != InfiniteOpt.ParameterFunctionIndex
ivref_param_nums = InfiniteOpt._parameter_numbers(ivref)
param_nums = var.parameter_nums
supp_indices = support_index_iterator(backend, var.group_int_idxs)
lookup_dict = Dict{Vector{Float64}, JuMP.VariableRef}()
sizehint!(lookup_dict, length(supp_indices))
for i in supp_indices
raw_supp = index_to_support(backend, i)
if any(!isnan(raw_supp[ivref_param_nums[k]]) && raw_supp[ivref_param_nums[k]] != v for (k, v) in eval_supps)
continue
end
ivref_supp = [haskey(eval_supps, j) ? eval_supps[j] : raw_supp[k]
for (j, k) in enumerate(ivref_param_nums)]
supp = raw_supp[param_nums]
lookup_dict[supp] = lookup_by_support(ivref, backend, ivref_supp)
end
data.infvar_lookup[rvref] = lookup_dict
end
data.semi_lookup[(ivref, eval_supps)] = rvref
return rvref
end
Expand Down
116 changes: 48 additions & 68 deletions src/TranscriptionOpt/transcribe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function transcribe_infinite_variables!(
vrefs = Array{JuMP.VariableRef, length(dims)}(undef, dims...)
supp_type = typeof(Tuple(ones(length(prefs)), prefs))
supps = Array{supp_type, length(dims)}(undef, dims...)
lookup_dict = Dict{Vector{Float64}, JuMP.VariableRef}()
lookup_dict = sizehint!(Dict{Vector{Float64}, JuMP.VariableRef}(), length(vrefs))
# create a variable for each support
for i in supp_indices
supp = index_to_support(backend, i)[param_nums]
Expand Down Expand Up @@ -190,7 +190,7 @@ function _transcribe_derivative_variable(dref, d, backend)
vrefs = Array{JuMP.VariableRef, length(dims)}(undef, dims...)
supp_type = typeof(Tuple(ones(length(prefs)), prefs))
supps = Array{supp_type, length(dims)}(undef, dims...)
lookup_dict = Dict{Vector{Float64}, JuMP.VariableRef}()
lookup_dict = sizehint!(Dict{Vector{Float64}, JuMP.VariableRef}(), length(vrefs))
# create a variable for each support
for i in supp_indices
supp = index_to_support(backend, i)[param_nums]
Expand Down Expand Up @@ -255,69 +255,6 @@ function transcribe_derivative_variables!(
return
end

# Setup the mapping for a given semi_infinite variable
function _set_semi_infinite_variable_mapping(
backend::TranscriptionBackend,
var::InfiniteOpt.SemiInfiniteVariable,
rvref::InfiniteOpt.GeneralVariableRef,
index_type
)
param_nums = var.parameter_nums
ivref = var.infinite_variable_ref
ivref_param_nums = InfiniteOpt._parameter_numbers(ivref)
eval_supps = var.eval_supports
group_idxs = var.group_int_idxs
prefs = InfiniteOpt.raw_parameter_refs(var)
# prepare for iterating over its supports
supp_indices = support_index_iterator(backend, group_idxs)
dims = size(supp_indices)[group_idxs]
vrefs = Array{JuMP.VariableRef, length(dims)}(undef, dims...)
supp_type = typeof(Tuple(ones(length(prefs)), prefs))
supps = Array{supp_type, length(dims)}(undef, dims...)
lookup_dict = Dict{Vector{Float64}, JuMP.VariableRef}()
valid_idxs = ones(Bool, dims...)
# map a variable for each support
for i in supp_indices
raw_supp = index_to_support(backend, i)
var_idx = i.I[group_idxs]
# ensure this support is valid with the reduced restriction
if any(!isnan(raw_supp[ivref_param_nums[k]]) && raw_supp[ivref_param_nums[k]] != v for (k, v) in eval_supps)
valid_idxs[var_idx...] = false
continue
end
# map to the current transcription variable
supp = raw_supp[param_nums]
ivref_supp = [haskey(eval_supps, j) ? eval_supps[j] : raw_supp[k]
for (j, k) in enumerate(ivref_param_nums)]
jump_vref = lookup_by_support(ivref, backend, ivref_supp)
@inbounds vrefs[var_idx...] = jump_vref
lookup_dict[supp] = jump_vref
@inbounds supps[var_idx...] = Tuple(supp, prefs)
end
# truncate vrefs if any supports were skipped because of dependent parameter supps and save
data = transcription_data(backend)
if !all(valid_idxs)
data.infvar_mappings[rvref] = vrefs[valid_idxs]
data.infvar_supports[rvref] = supps[valid_idxs]
data.valid_indices[rvref] = valid_idxs
else
data.infvar_mappings[rvref] = vrefs
data.infvar_supports[rvref] = supps
end
data.infvar_lookup[rvref] = lookup_dict
return
end

# Empty mapping dispatch for infinite parameter functions
function _set_semi_infinite_variable_mapping(
backend::TranscriptionBackend,
var::InfiniteOpt.SemiInfiniteVariable,
rvref::InfiniteOpt.GeneralVariableRef,
index_type::Type{InfiniteOpt.ParameterFunctionIndex}
)
return
end

"""
transcribe_semi_infinite_variables!(
backend::TranscriptionBackend,
Expand All @@ -342,8 +279,51 @@ function transcribe_semi_infinite_variables!(
var = object.variable
rvref = InfiniteOpt.GeneralVariableRef(model, idx)
# setup the mappings
idx_type = InfiniteOpt._index_type(InfiniteOpt.infinite_variable_ref(rvref))
_set_semi_infinite_variable_mapping(backend, var, rvref, idx_type)
ivref = var.infinite_variable_ref
if InfiniteOpt._index_type(ivref) != InfiniteOpt.ParameterFunctionIndex
param_nums = var.parameter_nums
ivref_param_nums = InfiniteOpt._parameter_numbers(ivref)
eval_supps = var.eval_supports
group_idxs = var.group_int_idxs
prefs = InfiniteOpt.raw_parameter_refs(var)
# prepare for iterating over its supports
supp_indices = support_index_iterator(backend, group_idxs)
dims = size(supp_indices)[group_idxs]
vrefs = Array{JuMP.VariableRef, length(dims)}(undef, dims...)
supp_type = typeof(Tuple(ones(length(prefs)), prefs))
supps = Array{supp_type, length(dims)}(undef, dims...)
lookup_dict = sizehint!(Dict{Vector{Float64}, JuMP.VariableRef}(), length(vrefs))
valid_idxs = ones(Bool, dims...)
# map a variable for each support
for i in supp_indices
raw_supp = index_to_support(backend, i)
var_idx = i.I[group_idxs]
# ensure this support is valid with the reduced restriction
if any(!isnan(raw_supp[ivref_param_nums[k]]) && raw_supp[ivref_param_nums[k]] != v for (k, v) in eval_supps)
valid_idxs[var_idx...] = false
continue
end
# map to the current transcription variable
supp = raw_supp[param_nums]
ivref_supp = [haskey(eval_supps, j) ? eval_supps[j] : raw_supp[k]
for (j, k) in enumerate(ivref_param_nums)]
jump_vref = lookup_by_support(ivref, backend, ivref_supp)
@inbounds vrefs[var_idx...] = jump_vref
lookup_dict[supp] = jump_vref
@inbounds supps[var_idx...] = Tuple(supp, prefs)
end
# truncate vrefs if any supports were skipped because of dependent parameter supps and save
data = transcription_data(backend)
if !all(valid_idxs)
data.infvar_mappings[rvref] = vrefs[valid_idxs]
data.infvar_supports[rvref] = supps[valid_idxs]
data.valid_indices[rvref] = valid_idxs
else
data.infvar_mappings[rvref] = vrefs
data.infvar_supports[rvref] = supps
end
data.infvar_lookup[rvref] = lookup_dict
end
end
return
end
Expand Down Expand Up @@ -570,7 +550,7 @@ function transcribe_measures!(
exprs = Array{JuMP.AbstractJuMPScalar, length(dims)}(undef, dims...)
supp_type = typeof(Tuple(ones(length(prefs)), prefs))
supps = Array{supp_type, length(dims)}(undef, dims...)
lookup_dict = Dict{Vector{Float64}, Int}()
lookup_dict = sizehint!(Dict{Vector{Float64}, Int}(), length(exprs))
# map a variable for each support
for (lin_idx, i) in enumerate(supp_indices)
raw_supp = index_to_support(backend, i)
Expand Down
16 changes: 10 additions & 6 deletions test/TranscriptionOpt/measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@
vref = GeneralVariableRef(m, -1, SemiInfiniteVariableIndex)
@test isequal(InfiniteOpt.add_semi_infinite_variable(tb, var), vref)
@test isequal(data.semi_infinite_vars, [var])
@test c in IOTO.transcription_variable(vref)
@test d in IOTO.transcription_variable(vref)
@test sort!(supports(vref)) == [([0., 0.], ), ([1., 1.], )]
@test IOTO.transcription_expression(vref, tb, [1., 0., 0.]) == c
@test IOTO.transcription_expression(vref, tb, [1., 1., 1.]) == d
# add one that has already been added internally
@test isequal(InfiniteOpt.add_semi_infinite_variable(tb, var), vref)
@test isequal(data.semi_infinite_vars, [var])
@test c in IOTO.transcription_variable(vref)
@test d in IOTO.transcription_variable(vref)
@test sort!(supports(vref)) == [([0., 0.], ), ([1., 1.], )]
@test IOTO.transcription_expression(vref, tb, [1., 0., 0.]) == c
@test IOTO.transcription_expression(vref, tb, [1., 1., 1.]) == d
# test with partially evaluated dependent parameter group
var2 = SemiInfiniteVariable(y, Dict(1 => 1., 2 => 1.0), [3], [2])
vref = GeneralVariableRef(m, -2, SemiInfiniteVariableIndex)
@test isequal(InfiniteOpt.add_semi_infinite_variable(tb, var2), vref)
@test isequal(data.semi_infinite_vars, [var, var2])
@test IOTO.transcription_expression(vref, tb, [1., 1., 1.]) == d
end
end
12 changes: 1 addition & 11 deletions test/TranscriptionOpt/transcribe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,11 @@
@test supports(dx) == [(0,), (1,)]
@test supports(dy) == [(0, [0, 0]) (0, [1, 1]); (1, [0, 0]) (1, [1, 1])]
end
# test _set_semi_infinite_variable_mapping
@testset "_set_semi_infinite_variable_mapping" begin
var = SemiInfiniteVariable(y, Dict{Int, Float64}(1 => 0), [1, 2], [1])
vref = GeneralVariableRef(m, -1, SemiInfiniteVariableIndex)
@test IOTO._set_semi_infinite_variable_mapping(tb, var, vref, SemiInfiniteVariableIndex) isa Nothing
@test IOTO.transcription_variable(vref) isa Vector{VariableRef}
@test length(IOTO.transcription_data(tb).infvar_mappings) == 7
@test IOTO.lookup_by_support(y, tb, [0., 0, 0]) == IOTO.lookup_by_support(vref, tb, [0., 0])
@test IOTO._set_semi_infinite_variable_mapping(tb, var, vref, ParameterFunctionIndex) isa Nothing
end
# test transcribe_semi_infinite_variables!
@testset "transcribe_semi_infinite_variables!" begin
@test IOTO.transcribe_semi_infinite_variables!(tb, m) isa Nothing
@test IOTO.transcription_variable(yrv) isa Vector{VariableRef}
@test length(IOTO.transcription_data(tb).infvar_mappings) == 8
@test length(IOTO.transcription_data(tb).infvar_mappings) == 7
@test IOTO.lookup_by_support(y, tb, [1., 0., 0.]) == IOTO.lookup_by_support(yrv, tb, [1., 0])
end
# test _update_point_info
Expand Down

0 comments on commit 59a8650

Please sign in to comment.