Skip to content

Commit

Permalink
Merge pull request #702 from LuxDL/rebase_docs
Browse files Browse the repository at this point in the history
Add activation functions doc reference (Rebase #694)
  • Loading branch information
avik-pal authored Jun 15, 2024
2 parents 60ff714 + 048c987 commit 163098a
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 2 deletions.
8 changes: 6 additions & 2 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ export default defineConfig({
text: 'Building Blocks', items: [
{ text: 'LuxCore', link: '/api/Building_Blocks/LuxCore' },
{ text: 'LuxLib', link: '/api/Building_Blocks/LuxLib' },
{ text: 'WeightInitializers', link: '/api/Building_Blocks/WeightInitializers' }
{ text: 'WeightInitializers', link: '/api/Building_Blocks/WeightInitializers' },
{ text: 'NNlib', link: 'https://fluxml.ai/NNlib.jl/dev/' },
{ text: 'Activation Functions', link: 'https://fluxml.ai/NNlib.jl/dev/reference/#Activation-Functions' }
]
},
{
Expand Down Expand Up @@ -203,7 +205,9 @@ export default defineConfig({
text: 'Building Blocks', collapsed: false, items: [
{ text: 'LuxCore', link: '/api/Building_Blocks/LuxCore' },
{ text: 'LuxLib', link: '/api/Building_Blocks/LuxLib' },
{ text: 'WeightInitializers', link: '/api/Building_Blocks/WeightInitializers' }]
{ text: 'WeightInitializers', link: '/api/Building_Blocks/WeightInitializers' },
{ text: 'NNlib', link: 'https://fluxml.ai/NNlib.jl/dev/' },
{ text: 'Activation Functions', link: 'https://fluxml.ai/NNlib.jl/dev/reference/#Activation-Functions' }]
},
{
text: 'Domain Specific Modeling', collapsed: false, items: [
Expand Down
48 changes: 48 additions & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ Reshapes the passed array to have a size of `(dims..., :)`
- AbstractArray of size `(dims..., size(x, ndims(x)))`
- Empty `NamedTuple()`
## Example
```jldoctest
julia> model = ReshapeLayer((2, 2))
ReshapeLayer(output_dims = (2, 2, :))
julia> rng = Random.default_rng();
Random.seed!(rng, 0);
ps, st = Lux.setup(rng, model);
x = randn(rng, Float32, (4, 1, 3));
julia> y, st_new = model(x, ps, st);
size(y)
(2, 2, 3)
```
"""
struct ReshapeLayer{N} <: AbstractExplicitLayer
dims::NTuple{N, Int}
Expand Down Expand Up @@ -48,6 +64,22 @@ Flattens the passed array into a matrix.
- AbstractMatrix of size `(:, size(x, ndims(x)))`
- Empty `NamedTuple()`
## Example
```jldoctest
julia> model = FlattenLayer()
FlattenLayer()
julia> rng = Random.default_rng();
Random.seed!(rng, 0);
ps, st = Lux.setup(rng, model);
x = randn(rng, Float32, (2, 2, 2, 2));
julia> y, st_new = model(x, ps, st);
size(y)
(8, 2)
```
"""
@kwdef @concrete struct FlattenLayer <: AbstractExplicitLayer
N = nothing
Expand Down Expand Up @@ -100,6 +132,22 @@ end
As the name suggests does nothing but allows pretty printing of layers. Whatever input is
passed is returned.
# Example
```jldoctest
julia> model = NoOpLayer()
NoOpLayer()
julia> rng = Random.default_rng();
Random.seed!(rng, 0);
ps, st = Lux.setup(rng, model);
x = 1
1
julia> y, st_new = model(x, ps, st)
(1, NamedTuple())
```
"""
struct NoOpLayer <: AbstractExplicitLayer end

Expand Down
20 changes: 20 additions & 0 deletions src/layers/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@ with `connection`.
`fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API)
See also [`SkipConnection`](@ref) which is `Parallel` with one identity.
## Example
```jldoctest
julia> model = Parallel(nothing, Dense(2, 1), Dense(2, 1))
Parallel(
layer_1 = Dense(2 => 1), # 3 parameters
layer_2 = Dense(2 => 1), # 3 parameters
) # Total: 6 parameters,
# plus 0 states.
julia> using Random;
rng = Random.seed!(123);
ps, st = Lux.setup(rng, model);
x1 = randn(rng, Float32, 2);
x2 = randn(rng, Float32, 2);
julia> size.(first(model((x1, x2), ps, st)))
((1,), (1,))
```
"""
@concrete struct Parallel{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)}
connection
Expand Down

1 comment on commit 163098a

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 163098a Previous: 60ff714 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3696.875 ns 3713.125 ns 1.00
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7233.5 ns 7203.333333333333 ns 1.00
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20908 ns 20819 ns 1.00
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9790.4 ns 9780.75 ns 1.00
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9125 ns 9093 ns 1.00
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4518.375 ns 4467.125 ns 1.01
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1171.1726618705036 ns 1159.7676056338028 ns 1.01
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1173.6119402985075 ns 1169.8357664233577 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1187.3636363636363 ns 1189.0078125 ns 1.00
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1788.8620689655172 ns 1776.9912280701756 ns 1.01
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.96897038081806 ns 180.60507757404795 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17242 ns 17362 ns 0.99
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16801 ns 16872 ns 1.00
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 39254 ns 39698.5 ns 0.99
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29134 ns 29275 ns 1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20038 ns 20047 ns 1.00
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17633 ns 17413 ns 1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4310.857142857143 ns 4328.785714285714 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3817.125 ns 3864.75 ns 0.99
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3949.875 ns 3916.125 ns 1.01
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4780.285714285715 ns 4892 ns 0.98
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1658.1 ns 1654.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 46801623 ns 39421463 ns 1.19
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57394014 ns 57774528 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 96959842 ns 72442185 ns 1.34
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 101748704 ns 89245855 ns 1.14
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 106543968 ns 73071392 ns 1.46
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11913745.5 ns 12092880 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 17580360 ns 17868830.5 ns 0.98
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 6976389 ns 7042694 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6936975.5 ns 7026994 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 18128983 ns 10167053.5 ns 1.78
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6378775 ns 6398143 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 720026710 ns 733077817 ns 0.98
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2535400202 ns 2576408758 ns 0.98
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 127145543 ns 145382696 ns 0.87
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 940372357.5 ns 793204426 ns 1.19
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3626594533 ns 2934972747 ns 1.24
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 212597349.5 ns 200283848.5 ns 1.06
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 725880434.5 ns 657077422.5 ns 1.10
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2880495088 ns 2630726384.5 ns 1.09
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 145177630 ns 125267603 ns 1.16
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 172990409 ns 174556269.5 ns 0.99
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 640453775.5 ns 655564353.5 ns 0.98
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 33890669 ns 34840242 ns 0.97
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 164117913 ns 165209849.5 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 638052367.5 ns 639944606 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29688588 ns 30130274.5 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 204705878 ns 186248429.5 ns 1.10
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 890135307 ns 716769749 ns 1.24
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 37377804.5 ns 35937506 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1215524499 ns 1212793695 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1855961983.5 ns 1876991316.5 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2369632518 ns 2315042302 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2493773538 ns 2546229569 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1954896920.5 ns 1829086938.5 ns 1.07
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 559179805 ns 562582630 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 315722092 ns 322654245 ns 0.98
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 316975900 ns 324527121 ns 0.98
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 446182933 ns 368705304 ns 1.21
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11872676 ns 12030571 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17732771 ns 17884778.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 18978200 ns 19210928.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23685126 ns 23885393.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17760726.5 ns 17872013 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1158445 ns 1165770.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5756319 ns 5886802 ns 0.98
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2040243.5 ns 2059749 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2022836 ns 2042150 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2062329.5 ns 2090230 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 199142 ns 205073 ns 0.97
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 291514 ns 295001 ns 0.99
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 264635 ns 266698 ns 0.99
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 362166 ns 370442 ns 0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 405158 ns 411668 ns 0.98
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 272480 ns 276035.5 ns 0.99
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 405398 ns 410075 ns 0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83125 ns 83606 ns 0.99
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81131 ns 81742 ns 0.99
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81042 ns 82474 ns 0.98
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86091 ns 87192 ns 0.99
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104505 ns 104675 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 208699675 ns 189862618.5 ns 1.10
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 319819436.5 ns 323925648.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 440840736 ns 396072434 ns 1.11
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 502066159 ns 457714601.5 ns 1.10
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 405717557 ns 374592898 ns 1.08
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 318284871.5 ns 346883743 ns 0.92
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 99765515 ns 101249018.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43558093 ns 43995641 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43439741 ns 43836879 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 70444994 ns 60056689 ns 1.17
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28060129 ns 28759179 ns 0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18601594 ns 19178597 ns 0.97
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19357835 ns 19643348.5 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23024994.5 ns 23527065 ns 0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 23879952 ns 24189371 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19459937 ns 19709627 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6479553.5 ns 6556235 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6476907 ns 6556525 ns 0.99
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6465645 ns 6516631 ns 0.99
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6493170 ns 6542791.5 ns 0.99

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.