Skip to content

Commit

Permalink
Improve names and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 1, 2024
1 parent 42a9bae commit 70901bc
Show file tree
Hide file tree
Showing 15 changed files with 226 additions and 76 deletions.
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ for i in eachindex(x1,x2,x3,x4,y1,y2)
end
```

With this package, we can apply `@fused` to reduce the number of reads and preserve the memory layout:
With this package, we can apply `@fused_direct` to reduce the number of reads and preserve the memory layout:

```julia
import MultiBroadcastFusion as MBF
Expand All @@ -62,7 +62,7 @@ y1 = rand(3,3)
y2 = rand(3,3)

# 4 reads, 2 writes
MBF.@fused begin
MBF.@fused_direct begin
@. y1 = x1 * x2 + x3 * x4
@. y2 = x1 * x3 + x2 * x4
end
Expand All @@ -76,10 +76,11 @@ Users can write custom implementations, using the `@make_type` and `@make_fused`

```julia
import MultiBroadcastFusion as MBF
import MultiBroadcastFusion: fused_direct

MBF.@make_type MyFusedMultiBroadcast
MBF.@make_fused MBF.fused_pairs MyFusedMultiBroadcast my_fused
# Now, `@fused` will call `Base.copyto!(::MyFusedMultiBroadcast)`. Let's define it:
MBF.@make_fused fused_direct MyFusedMultiBroadcast my_fused
# Now, `@fused_direct` will call `Base.copyto!(::MyFusedMultiBroadcast)`. Let's define it:
function Base.copyto!(fmb::MyFusedMultiBroadcast)
pairs = fmb.pairs
destinations = map(x->x.first, pairs)
Expand Down Expand Up @@ -117,7 +118,7 @@ end
macro get_fused_multi_broadcast(expr)
_pairs = gensym()
quote
$_pairs = $(esc(MBF.fused_pairs(expr)))
$_pairs = $(esc(MBF.fused_direct(expr)))
FusedMultiBroadcast($_pairs)
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/MultiBroadcastFusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module MultiBroadcastFusion
include(joinpath("collection", "utils.jl"))
include(joinpath("collection", "macros.jl"))
include(joinpath("collection", "code_lowered_single_expression.jl"))
include(joinpath("collection", "fused_pairs.jl"))
include(joinpath("collection", "fused_pairs_flexible.jl"))
include(joinpath("collection", "fused_direct.jl"))
include(joinpath("collection", "fused_assemble.jl"))

include(joinpath("execution", "fused_kernels.jl"))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,55 +1,91 @@
#####
##### Complex/flexible version
##### Fused assemble
#####

# General case: do nothing (identity)
transform_flex(x, sym) = x
transform_flex(s::Symbol, sym) = s
# Expression: recursively transform_flex for Expr
function transform_flex(e::Expr, sym)
transform_assemble(x, sym) = x
transform_assemble(s::Symbol, sym) = s
# Expression: recursively transform_assemble for Expr
function transform_assemble(e::Expr, sym)
if e.head == :macrocall && e.args[1] == Symbol("@__dot__")
se = code_lowered_single_expression(e)
margs = materialize_args(se)
subexpr = :($sym = ($sym..., Pair($(margs[1]), $(margs[2]))))
subexpr
else
Expr(transform_flex(e.head, sym), transform_flex.(e.args, sym)...)
Expr(
transform_assemble(e.head, sym),
transform_assemble.(e.args, sym)...,
)
end
end

"""
fused_pairs_flexible
fused_assemble(expr::Expr)
Function that fuses broadcast expressions
that stride flow control logic. For example:
Transforms the input expressions
into a runtime assembly of a tuple
of `Pair`s, containing (firsts)
the destination of broadcast expressions
and (seconds) the broadcasted objects.
For example:
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedMultiBroadcast
MBF.@make_fused fused_pairs_flexible MyFusedMultiBroadcast fused_flexible
using Test
expr_in = quote
@. y1 = x1 + x2 + x3 + x4
@. y2 = x2 + x3 + x4 + x5
end
expr_out = quote
tup = ()
tup = (tup..., Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)))
tup = (tup..., Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)))
tup
end
@test MBF.linefilter!(MBF.fused_assemble(expr_in, :tup)) ==
MBF.linefilter!(expr_out)
@test MBF.fused_assemble(expr_in, :tup) == expr_out
```
To use `MultiBroadcastFusion`'s `@fused_flexible` macro:
This can be used to make a custom kernel fusion macro:
```
import MultiBroadcastFusion as MBF
x = rand(1);y = rand(1);z = rand(1);
MBF.@fused_flexible begin
@. x += y
@. z += y
import MultiBroadcastFusion: fused_assemble
MBF.@make_type MyFusedBroadcast
MBF.@make_fused fused_assemble MyFusedBroadcast my_fused
Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!")
x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)
# 4 reads, 2 writes
@my_fused begin
for i in 1:3
@. y1 = x1
@. y2 = x1
end
end
```
Also see [`fused_direct`](@ref)
"""
function fused_pairs_flexible(expr::Expr, sym::Symbol)
check_restrictions_flexible(expr)
e = transform_flex(expr, sym)
fused_assemble(expr::Expr) = fused_assemble(expr, gensym())
function fused_assemble(expr::Expr, sym::Symbol)
check_restrictions_assemble(expr)
e = transform_assemble(expr, sym)
@assert e.head == :block
ex = Expr(:block, :($sym = ()), e.args..., sym)
# Filter out LineNumberNode, as this will not be valid due to prepending `tup = ()`
linefilter!(ex)
ex
end

function check_restrictions_flexible(expr::Expr)
function check_restrictions_assemble(expr::Expr)
for arg in expr.args
arg isa LineNumberNode && continue
s_error = if arg isa QuoteNode
Expand Down
46 changes: 35 additions & 11 deletions src/collection/fused_pairs.jl → src/collection/fused_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,52 @@ function transform(e::Expr)
end

"""
fused_pairs
fused_direct(expr::Expr)
Function that fuses broadcast expressions that
are immediately one after another. For example:
Directly transforms the input expression
into a tuple of `Pair`s, containing (firsts)
the destination of broadcast expressions and
(seconds) the broadcasted objects.
For example:
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedMultiBroadcast
MBF.@make_fused fused_pairs MyFusedMultiBroadcast fused
using Test
expr_in = quote
@. y1 = x1 + x2 + x3 + x4
@. y2 = x2 + x3 + x4 + x5
end
expr_out = :(tuple(
Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)),
Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)),
))
@test MBF.fused_direct(expr_in) == expr_out
```
To use `MultiBroadcastFusion`'s `@fused` macro:
This can be used to make a custom kernel fusion macro:
```
import MultiBroadcastFusion as MBF
x = rand(1);y = rand(1);z = rand(1);
MBF.@fused begin
@. x += y
@. z += y
import MultiBroadcastFusion: fused_direct
MBF.@make_type MyFusedBroadcast
MBF.@make_fused fused_direct MyFusedBroadcast my_fused
Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!")
x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)
# 4 reads, 2 writes
@my_fused begin
@. y1 = x1
@. y2 = x1
end
```
Also see [`fused_assemble`](@ref)
"""
function fused_pairs(expr::Expr)
function fused_direct(expr::Expr)
check_restrictions(expr)
e = transform(expr)
@assert e.head == :block
Expand Down
10 changes: 5 additions & 5 deletions src/collection/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ macro make_type(type_name)
end

"""
@make_fused fusion_type type_name fused_named
@make_fused fusion_style type_name fused_named
This macro
- Defines a type type_name
- Defines a macro, `@fused_name`, using the fusion type `fusion_type`
- Defines a macro, `@fused_name`, using the fusion type `fusion_style`
This allows users to flexibility
to customize their broadcast fusion.
Expand All @@ -27,7 +27,7 @@ to customize their broadcast fusion.
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedBroadcast
MBF.@make_fused MBF.fused_pairs MyFusedBroadcast my_fused
MBF.@make_fused MBF.fused_direct MyFusedBroadcast my_fused
Base.copyto!(fmb::MyFusedBroadcast) = println("You're ready to fuse!")
Expand All @@ -42,12 +42,12 @@ y2 = rand(3,3)
end
```
"""
macro make_fused(fusion_type, type_name, fused_name)
macro make_fused(fusion_style, type_name, fused_name)
t = esc(type_name)
f = esc(fused_name)
return quote
macro $f(expr)
_pairs = esc($(fusion_type)(expr))
_pairs = esc($(fusion_style)(expr))
t = $t
quote
Base.copyto!($t($_pairs))
Expand Down
4 changes: 2 additions & 2 deletions src/execution/fused_kernels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@make_type FusedMultiBroadcast
@make_fused fused_pairs FusedMultiBroadcast fused
@make_fused fused_pairs_flexible FusedMultiBroadcast fused_flexible
@make_fused fused_direct FusedMultiBroadcast fused
@make_fused fused_assemble FusedMultiBroadcast fused_assemble

struct CPU end
struct GPU end
Expand Down
2 changes: 1 addition & 1 deletion test/collection/expr_code_lowered_single_expression.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=
using Revise; include(joinpath("test", "expr_code_lowered_single_expression.jl"))
using Revise; include(joinpath("test", "collection", "expr_code_lowered_single_expression.jl"))
=#
using Test
import MultiBroadcastFusion as MBF
Expand Down
20 changes: 10 additions & 10 deletions test/collection/expr_errors_and_edge_cases.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=
using Revise; include(joinpath("test", "expr_errors_and_edge_cases.jl"))
using Revise; include(joinpath("test", "collection" "expr_errors_and_edge_cases.jl"))
=#
using Test
import MultiBroadcastFusion as MBF
Expand All @@ -21,7 +21,7 @@ import MultiBroadcastFusion as MBF
end
@. y1 = x1 + x2 + x3 + x4
end
@test_throws ErrorException("Loops are not allowed inside fused blocks") MBF.fused_pairs(
@test_throws ErrorException("Loops are not allowed inside fused blocks") MBF.fused_direct(
expr_in,
)
end
Expand All @@ -47,7 +47,7 @@ struct Foo end
end
@test_throws ErrorException(
"If-statements are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

bar() = nothing
Expand All @@ -61,7 +61,7 @@ bar() = nothing
end
@test_throws ErrorException(
"Function calls are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

@testset "Non-broadcast variable assignments" begin
Expand All @@ -74,7 +74,7 @@ end
end
@test_throws ErrorException(
"Non-broadcast assignments are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

@testset "No let-blocks" begin
Expand All @@ -87,7 +87,7 @@ end
end
@test_throws ErrorException(
"Let-blocks are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

@testset "Dangling symbols" begin
Expand All @@ -99,7 +99,7 @@ end
end
@test_throws ErrorException(
"Dangling symbols are not allowed inside fused blocks",
) MBF.fused_pairs(expr_in)
) MBF.fused_direct(expr_in)
end

@testset "quote" begin
Expand All @@ -110,7 +110,7 @@ end
quote end
@. y1 = x1 + x2 + x3 + x4
end
@test_throws ErrorException("Quotes are not allowed inside fused blocks") MBF.fused_pairs(
@test_throws ErrorException("Quotes are not allowed inside fused blocks") MBF.fused_direct(
expr_in,
)
end
Expand All @@ -127,10 +127,10 @@ end
Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)),
Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)),
))
@test MBF.fused_pairs(expr_in) == expr_out
@test MBF.fused_direct(expr_in) == expr_out
end

@testset "Empty" begin
expr_in = quote end
@test MBF.fused_pairs(expr_in) == :(tuple())
@test MBF.fused_direct(expr_in) == :(tuple())
end
Loading

0 comments on commit 70901bc

Please sign in to comment.