Skip to content

Commit

Permalink
Use New JuMP Macro API (infiniteopt#334)
Browse files Browse the repository at this point in the history
* Add new JuMP macro support

* add Pkg

* restore empty macro tests

* use name_with_index_expr

* revert to _error and imporve test coverage

* update to build_name_expr

* fix accidental name changes

* Update JuMP branch

* prepare for merging

* add test
  • Loading branch information
pulsipher authored Feb 1, 2024
1 parent ad13a33 commit f1b41c2
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 402 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
DataStructures = "0.14.2 - 0.18"
Distributions = "0.19 - 0.25"
FastGaussQuadrature = "0.3.2 - 0.4, 0.5, 1"
JuMP = "1.5"
JuMP = "1.18"
MutableArithmetics = "1"
Reexport = "0.2, 1"
julia = "^1.6"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random", "Suppressor"]
test = ["Pkg", "Test", "Random", "Suppressor"]
42 changes: 22 additions & 20 deletions src/MeasureToolbox/expectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,18 @@ end
An efficient wrapper for [`expect`](@ref). Please see its doc string more
information.
"""
macro expect(expr, prefs, args...)
_error(str...) = InfiniteOpt._macro_error(:expect, (expr, prefs, args...),
str...)
extra, kwargs, _, _ = InfiniteOpt._extract_kwargs(args)
if length(extra) > 0
_error("Unexpected positional arguments." *
"Must be of form @expect(expr, prefs, kwargs...).")
end
macro expect(args...)
error_fn = InfiniteOpt.JuMPC.build_error_fn(:expect, args, __source__)
pos_args, kwargs = InfiniteOpt.JuMPC.parse_macro_arguments(
error_fn,
args,
num_positional_args = 2
)
expr, prefs = pos_args
expression = MutableArithmetics.rewrite_and_return(expr)
esc_kwargs = map(i -> esc(i), kwargs)
return :( expect($expression, $(esc(prefs)); ($(esc_kwargs...))) )
code = :( expect($expression, $(esc(prefs))) )
InfiniteOpt.JuMPC.add_additional_args(code, [], kwargs)
return code
end

"""
Expand Down Expand Up @@ -252,15 +253,16 @@ end
A convenient wrapper for [`@expect`](@ref). The unicode symbol `𝔼` is produced by
`\\bbE`.
"""
macro 𝔼(expr, prefs, args...)
_error(str...) = InfiniteOpt._macro_error(:𝔼, (expr, prefs, args...),
str...)
extra, kwargs, _, _ = InfiniteOpt._extract_kwargs(args)
if length(extra) > 0
_error("Unexpected positional arguments." *
"Must be of form @𝔼(expr, prefs, kwargs...).")
end
macro 𝔼(args...)
error_fn = InfiniteOpt.JuMPC.build_error_fn(:𝔼, args, __source__)
pos_args, kwargs = InfiniteOpt.JuMPC.parse_macro_arguments(
error_fn,
args,
num_positional_args = 2
)
expr, prefs = pos_args
expression = MutableArithmetics.rewrite_and_return(expr)
esc_kwargs = map(i -> esc(i), kwargs)
return :( expect($expression, $(esc(prefs)); ($(esc_kwargs...))) )
code = :( expect($expression, $(esc(prefs))) )
InfiniteOpt.JuMPC.add_additional_args(code, [], kwargs)
return code
end
47 changes: 26 additions & 21 deletions src/MeasureToolbox/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1058,18 +1058,20 @@ An efficient wrapper for [`integral`](@ref integral(::JuMP.AbstractJuMPScalar, :
and [`integral`](@ref integral(::JuMP.AbstractJuMPScalar, ::AbstractArray{InfiniteOpt.GeneralVariableRef}, ::Union{Real, AbstractArray{<:Real}}, ::Union{Real, AbstractArray{<:Real}})).
Please see the above doc strings for more information.
"""
macro integral(expr, prefs, args...)
_error(str...) = InfiniteOpt._macro_error(:integral, (expr, prefs, args...),
str...)
extra, kwargs, _, _ = InfiniteOpt._extract_kwargs(args)
if length(extra) != 0 && length(extra) != 2
_error("Incorrect number of positional arguments for @integral. " *
"Must provide both bounds or no bounds.")
end
macro integral(args...)
error_fn = InfiniteOpt.JuMPC.build_error_fn(:expect, args, __source__)
pos_args, kwargs = InfiniteOpt.JuMPC.parse_macro_arguments(
error_fn,
args,
num_positional_args = 2:4
)
length(pos_args) == 3 && error_fn("Must specify both bounds.")
expr = popfirst!(pos_args)
prefs = popfirst!(pos_args)
expression = MutableArithmetics.rewrite_and_return(expr)
esc_kwargs = map(i -> esc(i), kwargs)
esc_extra = map(i -> esc(i), extra)
return :( integral($expression, $(esc(prefs)), $(esc_extra...); ($(esc_kwargs...))) )
code = :( integral($expression, $(esc(prefs))) )
InfiniteOpt.JuMPC.add_additional_args(code, pos_args, kwargs)
return code
end

"""
Expand All @@ -1082,15 +1084,18 @@ end
A convenient wrapper for [`@integral`](@ref). The unicode symbol `∫` is produced
via `\\int`.
"""
macro (expr, prefs, args...)
_error(str...) = InfiniteOpt._macro_error(:∫, (expr, prefs, args...), str...)
extra, kwargs, _, _ = InfiniteOpt._extract_kwargs(args)
if length(extra) != 0 && length(extra) != 2
_error("Incorrect number of positional arguments for @integral. " *
"Must provide both bounds or no bounds.")
end
macro (args...)
error_fn = InfiniteOpt.JuMPC.build_error_fn(:∫, args, __source__)
pos_args, kwargs = InfiniteOpt.JuMPC.parse_macro_arguments(
error_fn,
args,
num_positional_args = 2:4
)
length(pos_args) == 3 && error_fn("Must specify both bounds.")
expr = popfirst!(pos_args)
prefs = popfirst!(pos_args)
expression = MutableArithmetics.rewrite_and_return(expr)
esc_kwargs = map(i -> esc(i), kwargs)
esc_extra = map(i -> esc(i), extra)
return :( integral($expression, $(esc(prefs)), $(esc_extra...); ($(esc_kwargs...))) )
code = :( integral($expression, $(esc(prefs))) )
InfiniteOpt.JuMPC.add_additional_args(code, pos_args, kwargs)
return code
end
22 changes: 12 additions & 10 deletions src/MeasureToolbox/support_sums.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,17 @@ end
An efficient wrapper for [`support_sum`](@ref) please see its doc string for
more information.
"""
macro support_sum(expr, prefs, args...)
_error(str...) = InfiniteOpt._macro_error(:support_sum, (expr, prefs, args...),
str...)
extra, kwargs, _, _ = InfiniteOpt._extract_kwargs(args)
if length(extra) > 0
_error("Unexpected positional arguments." *
"Must be of form @support_sum(expr, prefs, kwargs...).")
end
macro support_sum(args...)
error_fn = InfiniteOpt.JuMPC.build_error_fn(:support_sum, args, __source__)
pos_args, kwargs = InfiniteOpt.JuMPC.parse_macro_arguments(
error_fn,
args,
num_positional_args = 2,
valid_kwargs = [:label]
)
expr, prefs = pos_args
expression = MutableArithmetics.rewrite_and_return(expr)
esc_kwargs = map(i -> esc(i), kwargs)
return :( support_sum($expression, $(esc(prefs)); ($(esc_kwargs...))) )
code = :( support_sum($expression, $(esc(prefs))) )
InfiniteOpt.JuMPC.add_additional_args(code, [], kwargs)
return code
end
15 changes: 10 additions & 5 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,16 @@ julia> deriv_expr = @deriv(x^2 + z, t^2)
```
"""
macro deriv(expr, args...)
# process the arugments
extra, kwargs, _, _ = _extract_kwargs(args)
# error if kwargs are given
_error(str...) = _macro_error(:deriv, (expr, args...), __source__, str...)
isempty(kwargs) || _error("Invalid keyword argument given.")
# make an error function
error_fn = JuMPC.build_error_fn(:deriv, (expr, args...), __source__)

# process the inputs
extra, _ = JuMPC.parse_macro_arguments(
error_fn,
args,
valid_kwargs = Symbol[]
)

# expand the parameter references as needed with powers
pref_exprs = []
for p in extra
Expand Down
Loading

0 comments on commit f1b41c2

Please sign in to comment.