Skip to content

Commit

Permalink
Create Magic macro
Browse files Browse the repository at this point in the history
- Contributed by @mcabbot in FluxML/Flux.jl#2107
  • Loading branch information
MilesCranmer committed Feb 10, 2023
1 parent 8650d54 commit 60fb970
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ export Split, Join
include("train.jl")
export shinkansen!

include("magic.jl")
export @Magic

end # module Fluxperimental
114 changes: 114 additions & 0 deletions src/magic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
@Magic(forward::Function; construct...)
Creates a layer by specifying some code to construct the layer, run immediately,
and (usually as a `do` block) a function for the forward pass.
You may think of `construct` as keywords, or better as a `let` block creating local variables.
Their names may be used within the body of the `forward` function.
Here is a linear model:
```
r = @Magic(w = rand(3)) do x
w .* x
end
r([1, 1, 1]) # x is set to [1, 1, 1].
```
Here is a linear model with bias and activation:
```
d = @Magic(in=5, out=7, W=randn(out, in), b=zeros(out), act=relu) do x
y = W * x
act.(y .+ b)
end
d(ones(5, 10)) # 7×10 Matrix as output.
```
Finally, here is a simple MLP:
```
using Flux
n_in = 1
n_out = 1
nlayers = 3
model = @Magic(
w1=Dense(n_in, 128),
w2=[Dense(128, 128) for i=1:nlayers],
w3=Dense(128, n_out),
act=relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
end
out = w3(embed)
return out
end
model(randn(n_in, 32)) # 1×32 Matrix as output.
```
We can train this model just like any `Chain`:
```
data = [([x], 2x-x^3) for x in -2:0.1f0:2]
optim = Flux.setup(Adam(), model)
for epoch in 1:1000
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
end
```
"""
macro Magic(fex, kwexs...)
# check input
Meta.isexpr(fex, :(->)) || error("expects a do block")
isempty(kwexs) && error("expects keyword arguments")
all(ex -> Meta.isexpr(ex, :kw), kwexs) || error("expects only keyword argumens")

# make strings
layer = "@Magic"
setup = join(map(ex -> string(ex.args[1], " = ", ex.args[2]), kwexs), ", ")
input = join(fex.args[1].args, ", ")
block = string(Base.remove_linenums!(fex).args[2])

# edit expressions
vars = map(ex -> ex.args[1], kwexs)
assigns = map(ex -> Expr(:(=), ex.args...), kwexs)
@gensym self
pushfirst!(fex.args[1].args, self)
addprefix!(fex, self, vars)

# assemble
return esc(quote
let
$(assigns...)
$MagicLayer($fex, ($layer, $setup, $input, $block); $(vars...))
end
end)
end

function addprefix!(ex::Expr, self, vars)
for i = 1:length(ex.args)
if ex.args[i] in vars
ex.args[i] = :($self.$(ex.args[i]))
else
addprefix!(ex.args[i], self, vars)
end
end
end
addprefix!(not_ex, self, vars) = nothing

struct MagicLayer{F,NT<:NamedTuple}
fun::F
strings::NTuple{4,String}
variables::NT
end
MagicLayer(f::Function, str::Tuple; kw...) = MagicLayer(f, str, NamedTuple(kw))
(m::MagicLayer)(x...) = m.fun(m.variables, x...)
MagicLayer(args...) = error("MagicLayer is meant to be constructed by the macro")
Flux.@functor MagicLayer

function Base.show(io::IO, m::MagicLayer)
layer, setup, input, block = m.strings
print(io, layer, "(", setup, ") do ", input)
return print(io, block[6:end])
end

0 comments on commit 60fb970

Please sign in to comment.