
Introduce macro to easily create custom layers #4

Merged: 19 commits merged into FluxML:master on Feb 24, 2023

Conversation

@MilesCranmer (Contributor) commented Feb 10, 2023

Edits:

  • Feb 12: added a name keyword for naming custom layers.
  • Feb 14: changed the macro name from @Magic to @compact. The macro is also no longer exported, so Flux.@compact should be used in the future.
  • Feb 14: overloaded Flux's printing function so that these layers are printed in a similar style to Chain, Dense, etc., with parameter counts listed and the output still re-digestible.

Introduction

This PR adds the @compact macro, which allows complex layers to be built without first creating a struct. It is completely compatible with Flux.Chain. The @compact macro itself was contributed by @mcabbott following a lengthy discussion in FluxML/Flux.jl#2107 between us as well as @ToucheSir and @darsnack. That code is copied here, along with some basic unit tests.

Here are some examples:

Linear model:

r = @compact(w = [1, 2, 3]) do x
    w .* x
end
r([1, 1, 1]) # = [1, 2, 3]

Here is a linear model with bias and activation:

d = @compact(in=5, out=7, W=randn(out, in), b=zeros(out), act=relu) do x
    y = W * x
    act.(y .+ b)
end
d(ones(5, 10))  # 7×10 Matrix as output.

Finally, here is a simple MLP:

using Flux

n_in = 1
n_out = 1
nlayers = 3

model = @compact(
    w1=Dense(n_in, 128),
    w2=[Dense(128, 128) for i=1:nlayers],
    w3=Dense(128, n_out),
    act=relu
) do x
    embed = act(w1(x))
    for w in w2
        embed = act(w(embed))
    end
    out = w3(embed)
    return out
end

model(randn(n_in, 32))  # 1×32 Matrix as output.

We can train this model just like any Chain:

data = [([x], 2x-x^3) for x in -2:0.1f0:2]
optim = Flux.setup(Adam(), model)

for epoch in 1:1000
    Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), model, data, optim)  # squared error on the 1-element output
end
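
After training, a quick sanity check (exact numbers will vary from run to run) is to compare the model against the target 2x - x^3:

model([0.5f0])  # should be close to [2*0.5 - 0.5^3] = [0.875]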

Discussion and Motivation

To see detailed discussion on this idea, please see threads FluxML/Flux.jl#2107 and #2.

The key motivation is that, while Chain is a really nice way to build many different complex layers in Flux.jl, it is sometimes significantly easier to write down models as forward functions in regular ol' code.

Most popular Python deep learning frameworks, such as PyTorch, have a simple and extensible API for creating complex neural network layers:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, n_in, n_out, act):
        super().__init__()
        self.w1 = nn.Linear(n_in, 100)
        self.w2 = nn.Linear(100, n_out)
        self.act = act

    def forward(self, x):
        return self.w2(self.act(self.w1(x)))

net = Net(1, 1, F.relu)

where the forward function is a regular Python function that allows arbitrary code (i.e., not Sequential/Chain). However, Flux.jl does not have something like this. The equivalent Flux implementation (without using Chain) would be:

struct Net
    w1::Dense
    w2::Dense
    act::Function
end

function (r::Net)(x)
    return r.w2(r.act(r.w1(x)))
end

@functor Net

function Net(; in, out, act)
    Net(Dense(in, 100), Dense(100, out), act)
end

net = Net(in=1, out=1, act=relu)

Compounding the difficulty is the fact that Julia structs cannot be changed without restarting the runtime. So if you are interactively developing a complex neural net, you can't add new parameters to the Net struct without restarting.
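
For example, adding a field to an already-defined struct in the same session fails (a minimal illustration; the error message is what current Julia 1.x releases print):

struct Net
    w1::Dense
end

# Later, trying to add a bias parameter:
struct Net
    w1::Dense
    b::Vector{Float32}
end
# ERROR: invalid redefinition of constant Net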

This simple @compact macro makes this all go away. Now it's even simpler to build custom layers in Flux.jl than in PyTorch:

function Net(; in, out, act)
    @compact(w1=Dense(in, 100), w2=Dense(100, out)) do x
        w2(act(w1(x)))
    end
end

net = Net(in=1, out=1, act=relu)

or even, for building things quickly,

net = @compact(w1=Dense(1, 100), w2=Dense(100, 1)) do x
    w2(relu(w1(x)))
end

This @compact macro is completely compatible with the existing Flux API such as Chain, so it is an easy way to build complex layers inside larger models.
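
For example, a @compact layer drops straight into a Chain (a quick sketch; the layer sizes are arbitrary, and depending on the final name the macro may need to be qualified, e.g. as Fluxperimental.@compact):

feature_net = @compact(w1=Dense(4 => 8, relu), w2=Dense(8 => 8, relu)) do x
    w2(w1(x))
end
model = Chain(feature_net, Dense(8 => 2), softmax)
model(ones(Float32, 4, 3))  # 2×3 Matrix of class probabilities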

PR Checklist

  • Tests are added
  • Documentation, if applicable

1-to-1 comparison:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(1, 100)
        self.w2 = nn.Linear(100, 1)
    def forward(self, x):
        x = F.relu(self.w1(x))
        return self.w2(x)

net = Net()
net = @compact(w1=Dense(1, 100), w2=Dense(100, 1)) do x
    x = relu(w1(x))
    return w2(x)
end

@MilesCranmer changed the title from "Create @Magic to easily create custom layers" to "Introduce @Magic macro to easily create custom layers" on Feb 11, 2023
@MilesCranmer (Contributor, Author) commented Feb 11, 2023

The test error looks to be with split_join.jl rather than this? @mcabbott

Edit: never mind, it looks like a bug in the nightly version of Julia.

@MilesCranmer (Contributor, Author) commented Feb 11, 2023

It might be nice to add a reserved keyword that defines the string representation. Rather than, e.g.,

julia> print(model)
@Magic(w1 = Dense(n_in, 128), w2 = [Dense(128, 128) for i = 1:nlayers], w3 = Dense(128, n_out), act = relu) do x
    embed = act(w1(x))
    for w = w2
        embed = act(w(embed))
    end
    out = w3(embed)
    return out
end

which might be too much detail when a simple function name would do, it might be nice to allow for

model = @Magic(..., name="MLP") ...

so that

julia> print(model)
MLP(w1 = Dense(n_in, 128), w2 = [Dense(128, 128) for i = 1:nlayers], w3 = Dense(128, n_out), act = relu)

which would make it a way to define new custom layers for Flux.Chain objects.

@MilesCranmer (Contributor, Author) commented Feb 11, 2023

Added the name keyword!

model = @Magic(w=randn(32, 32), name="Linear") do x, y
  tmp = sum(w .* x)
  return tmp + y
end
@test string(model) == "Linear(w = randn(32, 32))"

The default string representation is still the verbatim printout of the definition. But if you would like to name the model, such as if you are using this API to build up layers that you want to stack in a Flux.Chain, I think this is very practical.
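
For instance, one could build a small autoencoder out of named layers like this (just a sketch; the sizes are arbitrary):

enc = @Magic(w=Dense(784 => 32), name="Encoder") do x
    relu(w(x))
end
dec = @Magic(w=Dense(32 => 784), name="Decoder") do x
    w(x)
end
autoencoder = Chain(enc, dec)  # the printout can now refer to "Encoder" and "Decoder" rather than their full source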

@mcabbott (Member) commented:

Will look more later, but I have one quick printing comment. The goal in most of Flux is to make the printing re-digestible. name="Linear" breaks this, since there is no Linear object defined. So maybe it should instead print, say, Linear(...), something which is a syntax error, so that you can't even try.

@MilesCranmer (Contributor, Author) commented:

The goal in most of Flux is to make the printing re-digestible.

A user could choose to define the name to make it re-digestible in their library of custom layers:

function RectifiedLinear(; n_in, n_out)
    name = "RectifiedLinear(; n_in=$(n_in), n_out=$(n_out))"
    @eval @Magic(w=Dense($n_in, $n_out), name=$name) do x
        relu(w(x))
    end
end

which gives us:

julia> RectifiedLinear(; n_in=3, n_out=5)
RectifiedLinear(; n_in=3, n_out=5)

I'm not sure it is practical to always print out the full string representation used to construct custom layers, especially when things get very complex – otherwise you would be printing out an entire codebase. But still you could choose to do so by not setting name.

@MilesCranmer (Contributor, Author) commented:

Tweaked the printing a bit:

  1. Now prints the exact string passed to the macro. A user can customize it to their liking:
model = @Magic(w=randn(32, 32), name="Linear(...)") do x, y
  tmp = sum(w .* x)
  return tmp + y
end
@test string(model) == "Linear(...)"
  2. The printout for unnamed macro layers now prints one argument per line, which is helpful when many inputs are given:
julia> m = @Magic(w1=Dense(5 => 100, relu), w2=Dense(100 => 1)) do x
           w2(w1(x))
       end
@Magic(
  w1 = Dense(5 => 100, relu),
  w2 = Dense(100 => 1),
) do x
    w2(w1(x))
end

@MilesCranmer (Contributor, Author) commented Feb 12, 2023

It might also be nice to rewrite Base.show(..., m::MagicLayer) so that it looks at the variables, and, if they are in the Flux DSL, just print them normally. But leave other strings untouched.

That way, you could have printouts like this:

julia> m = @Magic(mlp=Chain(Dense(32 => 100, relu), Dense(100 => 1)), offset=randn(32)) do x
           mlp(x .+ offset)
       end;
julia> println(m)
@Magic(
  mlp=Chain(
    Dense(32 => 100, relu),               # 3_300 parameters
    Dense(100 => 1),                      # 101 parameters
  ),
  offset=randn(32),                       # 32 parameters
) do x
   mlp(x .+ offset)
end

@mcabbott do you know if there's a way to check if something has been @functor-ized? That could be a way to check whether the regular print method should be used or not.
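
(One possible heuristic, just a sketch and not necessarily the right approach: anything that Functors.jl treats as a non-leaf node, or a plain array, gets the regular Flux printing, and everything else keeps the verbatim string.)

using Flux, Functors

# Hypothetical helper: `true` if `x` should be printed with the standard Flux
# machinery, i.e. it is a parameter array or a struct with functor children.
print_like_flux(x) = x isa AbstractArray || !Functors.isleaf(x)

print_like_flux(Dense(2 => 3))  # true: Dense has been @functor-ized, so it has children
print_like_flux(relu)           # false (with current Functors defaults): a plain function has no children
print_like_flux(randn(3))       # true: arrays hold parameters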

@MilesCranmer (Contributor, Author) commented:

@marius311 proposed tweaking the name (which I am in favor of!) and suggested @withparams. What do others think?

I wonder if something with let in it would be a good idea, because the declaration is similar to a let statement. e.g., @letmodel?

@darsnack (Member) commented:

@compactmodel would evoke a similarity to flax.linen.@compact for folks coming from that space.

@MilesCranmer (Contributor, Author) commented Feb 13, 2023

That's a good one too. Or maybe just Flux.@compact since the namespace offers context?


Also, should the macro be capitalized because it's creating an object?

@MilesCranmer changed the title from "Introduce @Magic macro to easily create custom layers" to "Introduce macro to easily create custom layers" on Feb 13, 2023
@darsnack (Member) commented:

@compact is good as long as it isn't exported. I would say helpers that return objects are still lowercase (which is the case here). Uppercase should be reserved for types and constructors.

(Resolved review threads on Project.toml and src/magic.jl.)
@darsnack (Member) commented:

I think you could do something like https://github.com/FluxML/Metalhead.jl/blob/master/src/Metalhead.jl if you want to participate in the Flux fancy show methods.
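
For reference, that pattern looks roughly like the following (a sketch only; CompactLayer is the struct this PR defines, and _big_show/_layer_show are Flux internals):

function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
    if get(io, :typeinfo, nothing) === nothing   # e.g. at the top level of the REPL
        Flux._big_show(io, m)
    elseif !get(io, :compact, false)             # e.g. printed inside a Vector
        Flux._layer_show(io, m)
    else
        show(io, m)
    end
end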

@MilesCranmer (Contributor, Author) commented:

I have now renamed the macro to @compact and removed the export.

@MilesCranmer (Contributor, Author) commented Feb 14, 2023

@darsnack @mcabbott I updated the printing to overload _big_show with custom logic for CompactLayer. Now it works pretty nicely; check it out:

julia> model = @compact(w1=Dense(32, 32, relu), w2=Dense(32, 32)) do x
         w2(w1(x))
       end
@compact(
  w1 = Dense(32 => 32, relu),           # 1_056 parameters
  w2 = Dense(32 => 32),                 # 1_056 parameters
) do x
    w2(w1(x))
end                  # Total: 4 arrays, 2_112 parameters, 8.602 KiB.

It also works inside other Flux models:

julia> Chain(model, Dense(32, 32))
Chain(
  @compact(
    w1 = Dense(32 => 32, relu),         # 1_056 parameters
    w2 = Dense(32 => 32),               # 1_056 parameters
  ) do x 
      w2(w1(x))
  end,
  Dense(32 => 32),                      # 1_056 parameters
)                   # Total: 6 arrays, 3_168 parameters, 12.961 KiB.

Or even with a hierarchy of @compact:

julia> model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
         w2(w1(x))
       end;
julia> model2 = @compact(w1=model1, w2=Dense(32=>32, relu)) do x
         w2(w1(x))
       end
@compact(
  w1 = @compact(
    w1 = Dense(32 => 32, relu),         # 1_056 parameters
    w2 = Dense(32 => 32, relu),         # 1_056 parameters
  ) do x 
      w2(w1(x))
  end,
  w2 = Dense(32 => 32, relu),           # 1_056 parameters
) do x 
    w2(w1(x))
end                  # Total: 6 arrays, 3_168 parameters, 13.047 KiB.

This is re-digestible too! (For the most part: if you pass arrays of Dense, it prints them as Array(Dense(32, 32), Dense(32, 32), ...) rather than [Dense(32, 32), Dense(32, 32), ...], which maybe needs to be patched in Flux.jl.)

Another difficulty is that w1 = randn(32) is no longer printed verbatim; instead, all 32 values are printed. Fixing this seems like a larger change, since we would still want to include the parameter count in addition to printing "randn(32)". We can fix this later.

@MilesCranmer (Contributor, Author) commented Feb 14, 2023

Got arrays working too 🎉

julia> model = @compact(x=randn(5), w=Dense(32=>32)) do s
           x .* s
       end;
julia> model
@compact(
  x = randn(5),                         # 5 parameters
  w = Dense(32 => 32),                  # 1_056 parameters
) do s 
    x .* s
end                  # Total: 3 arrays, 1_061 parameters, 4.527 KiB.

if get(io, :typeinfo, nothing) === nothing  # e.g., top level of REPL
    Flux._big_show(io, m)
elseif !get(io, :compact, false)  # e.g., printed inside a Vector, but not a matrix
    Flux._layer_show(io, m)
@MilesCranmer (Contributor, Author) commented on the hunk above:

Should this be overloaded too? What is the difference in _layer_show?

- This is because the size depends on the indentation of the model, which might change in the future (and result in confusing errors!)
@MilesCranmer (Contributor, Author) commented:

Hey @darsnack, @mcabbott, if you have a chance this week, do you think you might be able to review this PR? Thanks! - Miles

@ToucheSir (Member) left a comment:

In the spirit of Fluxperimental.jl being somewhere to explore and possibly break new APIs, I think we should go ahead. Thanks @MilesCranmer, will merge tomorrow morning if there are no further objections.

@darsnack (Member) left a comment:

No objections from me! Really appreciate this big effort.

@ToucheSir merged commit d917e17 into FluxML:master on Feb 24, 2023
@MilesCranmer (Contributor, Author) commented:

Awesome, thanks!

Also, could I ask what the expected timeline, or required milestones, would be for this to eventually join the Flux.jl tree? I've been using it locally and it's extremely useful; it helped me get started much quicker.

@MilesCranmer (Contributor, Author) commented:

Ping regarding my question above 🙂

It would be fantastic to have this in the normal Flux.jl library. Others have even mentioned they would be interested in using it for non-deep-learning tasks.
